openfree commited on
Commit
0398d1e
Β·
verified Β·
1 Parent(s): 63ef9ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -317
app.py CHANGED
@@ -1,76 +1,100 @@
1
  import os
 
2
  import json
3
- import time
4
  from typing import List, Dict, Tuple
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  import streamlit as st
7
  import requests
8
 
9
- # Guard imports for optional dependencies
10
  try:
11
  import torch
12
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
13
  TORCH_AVAILABLE = True
14
- except Exception:
15
  TORCH_AVAILABLE = False
 
 
 
 
 
 
 
 
16
 
17
  try:
18
  from datasets import load_dataset
19
  DATASETS_AVAILABLE = True
20
- except Exception:
21
  DATASETS_AVAILABLE = False
 
22
 
23
  try:
24
  from sentence_transformers import SentenceTransformer
25
  SENTENCE_TRANSFORMERS_AVAILABLE = True
26
- except Exception:
27
  SENTENCE_TRANSFORMERS_AVAILABLE = False
 
28
 
29
  try:
30
  import faiss
31
  FAISS_AVAILABLE = True
32
- except Exception:
33
  FAISS_AVAILABLE = False
 
34
 
35
  try:
36
  from Bio import SeqIO
37
  BIOPYTHON_AVAILABLE = True
38
- except Exception:
39
  BIOPYTHON_AVAILABLE = False
 
40
 
41
- # Constants
42
  APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
43
- DISCLAIMER = (
44
- "This tool is for research/education and is not a medical device. "
45
- "Do not use outputs for diagnosis or treatment decisions."
46
- )
47
 
48
  # --------------- Helper Functions ---------------
49
 
50
  def get_secret(name: str, fallback: str = "") -> str:
51
- """Get secret from st.secrets, environment, or fallback"""
52
  try:
53
- if hasattr(st, 'secrets'):
54
- return st.secrets.get(name, os.environ.get(name, fallback))
 
55
  except:
56
  pass
 
57
  return os.environ.get(name, fallback)
58
 
59
  def brave_search(query: str, count: int = 5) -> List[Dict]:
60
- """Search using Brave Search API"""
61
  key = get_secret("BRAVE_API_KEY", "")
62
  if not key:
63
- return [{"title": "BRAVE_API_KEY is missing",
64
- "url": "",
65
- "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar to enable web search."}]
 
 
66
 
67
  url = "https://api.search.brave.com/res/v1/web/search"
68
  headers = {
69
  "Accept": "application/json",
70
- "X-Subscription-Token": key,
71
- "Accept-Encoding": "gzip"
72
  }
73
- params = {"q": query, "count": count, "country": "us"}
74
 
75
  try:
76
  r = requests.get(url, headers=headers, params=params, timeout=15)
@@ -81,206 +105,198 @@ def brave_search(query: str, count: int = 5) -> List[Dict]:
81
  results.append({
82
  "title": item.get("title", ""),
83
  "url": item.get("url", ""),
84
- "snippet": item.get("description", ""),
85
  })
86
- return results if results else [{"title": "No results", "url": "", "snippet": "Query returned no results."}]
87
  except Exception as e:
88
- return [{"title": "Search error", "url": "", "snippet": str(e)}]
89
 
90
- def call_fireworks(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
91
- """Call Fireworks AI chat completion API"""
92
  api_key = get_secret("FIREWORKS_API_KEY", "")
93
  if not api_key:
94
- return "FIREWORKS_API_KEY is missing. Set it in Secrets or the sidebar."
95
 
96
  url = "https://api.fireworks.ai/inference/v1/chat/completions"
97
  payload = {
98
  "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
 
99
  "max_tokens": max_tokens,
 
100
  "top_p": 1,
101
- "top_k": 40,
102
- "presence_penalty": 0,
103
  "frequency_penalty": 0,
104
- "temperature": temperature,
105
- "messages": messages
106
  }
107
  headers = {
108
- "Accept": "application/json",
109
  "Content-Type": "application/json",
110
  "Authorization": f"Bearer {api_key}"
111
  }
112
 
113
  try:
114
- r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
115
  r.raise_for_status()
116
- data = r.json()
117
- return data["choices"][0]["message"]["content"]
118
  except Exception as e:
119
- return f"[Fireworks API error] {e}"
120
 
121
- def load_text_from_file(upload) -> str:
122
  """Load text from uploaded file"""
123
  name = upload.name.lower()
124
- content = upload.read()
125
 
126
  try:
 
127
  text = content.decode("utf-8", errors="ignore")
128
  except:
129
- text = str(content)
130
 
131
- # FASTA file handling
132
  if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
133
- upload.seek(0)
134
  try:
 
135
  records = list(SeqIO.parse(upload, "fasta"))
136
- seqs = []
137
- for r in records:
138
- seqs.append(f">{r.id}\n{str(r.seq)}")
139
  return "\n\n".join(seqs)
140
  except:
141
  pass
142
 
143
  return text
144
 
145
- def build_vector_index(texts: List[str], embedder_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
146
- """Build FAISS vector index from texts"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
148
- return None, None, None
149
 
150
  try:
151
- model = SentenceTransformer(embedder_name)
152
- emb = model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
153
- dim = emb.shape[1]
 
154
  index = faiss.IndexFlatIP(dim)
155
- index.add(emb.astype("float32"))
156
- return index, emb, model
 
157
  except Exception as e:
158
- st.warning(f"Failed to build index: {e}")
159
- return None, None, None
160
 
161
- def search_index(query: str, index, model, texts: List[str], k: int = 4):
162
  """Search vector index"""
163
  if index is None or model is None:
164
  return []
165
 
166
  try:
167
- q = model.encode([query], normalize_embeddings=True)
168
- D, I = index.search(q.astype("float32"), k)
169
- hits = []
 
170
  for idx, score in zip(I[0], D[0]):
171
  if 0 <= idx < len(texts):
172
- hits.append({"score": float(score), "text": texts[idx]})
173
- return hits
 
 
 
174
  except:
175
  return []
176
 
177
- def esm2_embed(seq: str, model_id: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
178
- """Generate ESM-2 embedding for protein sequence"""
179
- if not TORCH_AVAILABLE:
180
- return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
181
 
182
  try:
183
- from transformers import AutoTokenizer, AutoModelForMaskedLM
184
- import torch
185
-
186
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
187
- model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=True)
188
  model.eval()
189
 
190
  with torch.no_grad():
191
- toks = tokenizer(seq, return_tensors="pt")
192
- out = model(**toks, output_hidden_states=True)
193
- hidden = out.hidden_states[-1].mean(dim=1).squeeze(0)
194
- vec = hidden.detach().cpu().numpy()
195
- return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
 
 
 
 
196
  except Exception as e:
197
  return {"error": str(e)}
198
 
199
- def dna_embed(seq: str, model_id: str = "zhihan1996/DNABERT-2-117M") -> Dict:
200
- """Generate DNA embedding"""
201
- if not TORCH_AVAILABLE:
202
- return {"error": "Transformers/torch not available. Please wait for dependencies to install."}
203
 
204
  try:
205
- from transformers import AutoTokenizer, AutoModel
206
- import torch
207
-
208
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
209
- model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
210
  model.eval()
211
 
212
  with torch.no_grad():
213
- toks = tokenizer(seq, return_tensors="pt", truncation=True, max_length=4096)
214
- out = model(**toks, output_hidden_states=True)
215
- hidden = out.last_hidden_state.mean(dim=1).squeeze(0)
216
- vec = hidden.detach().cpu().numpy()
217
- return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
 
 
 
 
218
  except Exception as e:
219
  return {"error": str(e)}
220
 
221
- def chunk_text(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
222
- """Chunk text with overlap"""
223
- text = text.replace("\r\n", "\n")
224
- chunks = []
225
- start = 0
226
-
227
- while start < len(text):
228
- end = min(len(text), start + chunk_size)
229
- chunks.append(text[start:end])
230
- if end >= len(text):
231
- break
232
- start = end - overlap
233
-
234
- return chunks
235
-
236
- def build_context(user_query: str, index, index_model, docs: List[str], loaded_datasets: List, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
237
- """Build context from various sources"""
238
  pieces = []
239
  sources = []
240
-
241
- # From uploaded files
242
- if index is not None and index_model is not None and docs:
243
- hits = search_index(user_query, index, index_model, docs, k=4)
244
- for h in hits:
245
- pieces.append(f"[FILE] {h['text'][:800]}")
246
- sources.append({"type": "file", "text": h["text"][:200]})
247
 
248
- # From datasets
249
- for rid, sample in loaded_datasets:
250
- if sample:
251
- pieces.append(f"[DATASET {rid}] {sample}")
252
- sources.append({"type": "dataset", "id": rid})
 
253
 
254
- # From web
255
  if use_web:
256
- results = brave_search(user_query, count=web_k)
257
  for r in results:
258
- snippet = r.get("snippet", "")
259
- url = r.get("url", "")
260
- title = r.get("title", "")
261
- pieces.append(f"[WEB] {title}\n{snippet}\n{url}")
262
- sources.append({"type": "web", "title": title, "url": url})
263
 
264
- context = "\n\n---\n\n".join(pieces)[:6000]
265
  return context, sources
266
 
267
- def chat_answer(user_query: str, index, index_model, docs: List[str], loaded_datasets: List, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
268
- """Generate chat answer with context"""
269
- context, sources = build_context(user_query, index, index_model, docs, loaded_datasets, use_web, web_k)
270
  system = (
271
- "You are a concise, careful bioinformatics assistant for protein and DNA. "
272
- "Answer with factual, verifiable statements. "
273
- "When uncertain, say so briefly. "
274
- "Never give medical advice. Provide short references as plain URLs or titles if present in context. "
275
- "User uploads and web/dataset snippets are provided as context below."
276
  )
277
- prompt = f"Context:\n{context}\n\nUser question:\n{user_query}\n\nAnswer in Korean if the user used Korean; otherwise match user language."
 
 
278
  messages = [
279
  {"role": "system", "content": system},
280
- {"role": "user", "content": prompt}
281
  ]
282
- answer = call_fireworks(messages, temperature=0.4, max_tokens=1200)
283
- return answer, sources
284
 
285
  # --------------- Streamlit UI ---------------
286
 
@@ -288,215 +304,224 @@ st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
288
  st.title(APP_TITLE)
289
  st.caption(DISCLAIMER)
290
 
291
- # Check dependencies status
292
- if not TORCH_AVAILABLE:
293
- st.warning("⏳ PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")
294
-
295
- # Initialize session state
296
- if 'docs' not in st.session_state:
297
  st.session_state.docs = []
298
- if 'index' not in st.session_state:
299
  st.session_state.index = None
300
- if 'index_model' not in st.session_state:
301
- st.session_state.index_model = None
302
- if 'loaded_datasets' not in st.session_state:
303
- st.session_state.loaded_datasets = []
304
 
305
- # Sidebar configuration
306
  with st.sidebar:
307
- st.header("Keys and settings")
308
- fw_key = st.text_input("FIREWORKS_API_KEY", value=get_secret("FIREWORKS_API_KEY", ""), type="password")
309
- brave_key = st.text_input("BRAVE_API_KEY", value=get_secret("BRAVE_API_KEY", ""), type="password")
 
 
 
 
 
 
 
 
 
310
 
311
  if fw_key:
312
  os.environ["FIREWORKS_API_KEY"] = fw_key
313
  if brave_key:
314
  os.environ["BRAVE_API_KEY"] = brave_key
315
 
316
- st.markdown("### Model selections")
317
- esm2_id = st.text_input(
318
- "Protein model (ESM-2)",
319
- value="facebook/esm2_t6_8M_UR50D",
320
- help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
321
- )
322
- dna_id = st.text_input(
323
- "DNA model",
324
- value="zhihan1996/DNABERT-2-117M",
325
- help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
326
- )
327
-
328
- use_web = st.checkbox("Use Brave web search for context", value=True)
329
- web_k = st.slider("Web results", 1, 10, 4)
330
 
331
- st.markdown("### Datasets (optional)")
332
- dataset_ids = st.text_area(
333
- "Datasets to load (one per line)",
334
- value="",
335
- help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'"
 
 
336
  )
337
 
338
- st.divider()
339
- st.markdown("Files you upload are indexed locally and used for answers.")
340
-
341
- # Main tabs
342
- tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])
343
-
344
- # File upload section
345
- with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
346
- uploads = st.file_uploader(
347
- "Add files",
348
- type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
349
- accept_multiple_files=True,
350
- key="file_uploader"
351
  )
352
 
353
- if uploads:
354
  docs = []
355
- for up in uploads:
356
  try:
357
- txt = load_text_from_file(up)
358
- docs.extend(chunk_text(txt))
 
359
  except Exception as e:
360
- st.warning(f"Failed to read {up.name}: {e}")
361
-
362
- st.session_state.docs = docs
363
- st.caption(f"Indexed chunks: {len(docs)}")
364
 
365
- # Build index if docs available
366
- if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
367
- with st.spinner("Building vector index..."):
368
- index, emb, index_model = build_vector_index(docs)
369
- st.session_state.index = index
370
- st.session_state.index_model = index_model
371
- else:
372
- st.caption("No files uploaded yet")
373
-
374
- # Load datasets if specified
375
- if dataset_ids.strip() and DATASETS_AVAILABLE:
376
- dataset_list = [x.strip() for x in dataset_ids.splitlines() if x.strip()]
377
- if dataset_list != [d[0] for d in st.session_state.loaded_datasets]:
378
- st.session_state.loaded_datasets = []
379
- for rid in dataset_list:
380
- with st.spinner(f"Loading dataset {rid}..."):
381
- try:
382
- ds = load_dataset(rid)
383
- sample = ""
384
- for split in ds.keys():
385
- try:
386
- row = ds[split][0]
387
- sample = json.dumps(row, ensure_ascii=False)[:500]
388
- break
389
- except:
390
- pass
391
- st.session_state.loaded_datasets.append((rid, sample))
392
- st.success(f"Loaded {rid}")
393
- except Exception as e:
394
- st.error(f"Failed to load {rid}: {e}")
395
 
396
  # Chat tab
397
- with tabs[0]:
398
- st.subheader("Chat")
399
- q = st.text_area("Ask a question about protein/DNA", value="ESM-2 μž„λ² λ”©μ€ λ‹¨λ°±μ§ˆ κΈ°λŠ₯ 해석에 μ–΄λ–»κ²Œ λ„μ›€λ˜λ‚˜μš”?")
400
-
401
- if st.button("Answer", type="primary"):
402
- with st.spinner("Thinking..."):
403
- ans, srcs = chat_answer(
404
- q,
405
- st.session_state.index,
406
- st.session_state.index_model,
407
- st.session_state.docs,
408
- st.session_state.loaded_datasets,
409
- use_web,
410
- web_k
411
- )
412
- st.write(ans)
413
-
414
- if srcs:
415
- st.markdown("#### Sources")
416
- for s in srcs:
417
- if s.get("type") == "web" and s.get("url"):
418
- st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
419
- elif s.get("type") == "dataset":
420
- st.markdown(f"- dataset: {s.get('id')}")
421
- elif s.get("type") == "file":
422
- snippet = s.get("text", "")
423
- st.markdown(f"- file snippet: {snippet[:120]}...")
 
 
 
 
 
 
 
 
424
 
425
  # Protein tab
426
- with tabs[1]:
427
- st.subheader("Protein analysis")
428
- seq = st.text_area("Protein sequence (amino acids only)", value="MKTIIALSYIFCLVFADYKDDDDK")
 
 
 
 
 
429
 
430
  col1, col2 = st.columns(2)
 
431
  with col1:
432
- st.caption("ESM-2 embedding")
433
- if st.button("Run ESM-2", key="run_esm2"):
434
- with st.spinner("Computing ESM-2 embedding..."):
435
- out = esm2_embed(seq.strip(), esm2_id)
436
- if "error" in out:
437
- st.error(out["error"])
 
 
 
 
 
 
 
 
 
 
438
  else:
439
- st.success(f"Vector size: {out['hidden_size']}")
440
- st.json({"embedding_preview": out["embedding"][:8]})
441
 
442
  with col2:
443
- st.caption("Quick stats")
444
- s = seq.replace("\n", "").replace(" ", "").upper()
445
- length = len(s)
446
- aa_set = sorted(set(list(s)))
447
- st.write(f"Length: {length}")
448
- st.write(f"Unique AAs: {''.join(aa_set)[:30]}")
449
 
450
  # DNA tab
451
- with tabs[2]:
452
- st.subheader("DNA analysis")
453
- dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
454
-
455
- col3, col4 = st.columns(2)
456
- with col3:
457
- st.caption("DNA embedding")
458
- if st.button("Run DNA embed", key="run_dna"):
459
- with st.spinner("Computing DNA embedding..."):
460
- out = dna_embed(dseq.strip(), dna_id)
461
- if "error" in out:
462
- st.error(out["error"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  else:
464
- st.success(f"Vector size: {out['hidden_size']}")
465
- st.json({"embedding_preview": out["embedding"][:8]})
466
-
467
- with col4:
468
- st.caption("GC content")
469
- s = dseq.upper().replace("N", "").replace(" ", "").replace("\n", "")
470
- if len(s) > 0:
471
- gc = (s.count("G") + s.count("C")) / len(s)
472
- else:
473
- gc = 0
474
- st.write(f"Length: {len(s)}")
475
- st.write(f"GC: {gc:.3f}")
476
-
477
- # Examples tab
478
- with tabs[3]:
479
- st.subheader("Examples")
480
- st.markdown("### Example questions you can ask:")
481
- st.markdown("- μ—…λ‘œλ“œν•œ FASTAμ—μ„œ νŠΉμ • λ‹¨λ°±μ§ˆμ˜ κΈ°λŠ₯ μš”μ•½κ³Ό 변이 영ν–₯ 질문")
482
- st.markdown("- DNA μ„œμ—΄μ—μ„œ ν”„λ‘œλͺ¨ν„° κ°€λŠ₯μ„±κ³Ό μ „μ‚¬μΈμž λͺ¨ν‹°ν”„ κ΄€λ ¨ κ·Όκ±° μš”μ²­")
483
- st.markdown("- Enzyme active site κ·Όμ ‘ λ³€μ΄μ˜ 리슀크 해석 (연ꡬ 관점)")
484
- st.markdown("- ENCODE/UniProt/AlphaFold κ°œλ… μ„€λͺ… μš”μ²­")
485
- st.markdown("- RAG 기반으둜 λ¬Έμ„œ 인용과 ν•¨κ»˜ κ°„λž΅ λ‹΅λ³€ μš”μ²­")
486
 
487
  # About tab
488
- with tabs[4]:
489
- st.subheader("About this Space")
490
- st.write("**Models suggested:**")
491
- st.write("- ESM-2 for proteins")
492
- st.write("- DNABERT-2 or Nucleotide Transformer for DNA")
493
- st.write("")
494
- st.write("**Common datasets:**")
495
- st.write("- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar")
496
- st.write("")
497
- st.write("**Features:**")
498
- st.write("- Web search powered by Brave Search API")
499
- st.write("- LLM powered by Fireworks AI")
500
- st.write("- Vector search with FAISS")
501
- st.write("")
502
- st.info(DISCLAIMER)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import sys
3
  import json
 
4
  from typing import List, Dict, Tuple
5
 
6
+ # Streamlit μ‹€ν–‰ 확인
7
+ def _running_in_streamlit() -> bool:
8
+ try:
9
+ from streamlit.runtime.scriptrunner import get_script_run_ctx
10
+ return get_script_run_ctx() is not None
11
+ except Exception:
12
+ return False
13
+
14
+ if not _running_in_streamlit():
15
+ print("이 앱은 Streamlit μ„œλ²„λ‘œ μ‹€ν–‰ν•΄μ•Ό ν•©λ‹ˆλ‹€.")
16
+ print("λͺ…λ Ή: streamlit run app.py --server.port=8501 --server.address=0.0.0.0")
17
+ sys.exit(0)
18
+
19
  import streamlit as st
20
  import requests
21
 
22
+ # 선택적 μ˜μ‘΄μ„± κ°€λ“œ
23
  try:
24
  import torch
 
25
  TORCH_AVAILABLE = True
26
+ except ImportError:
27
  TORCH_AVAILABLE = False
28
+ print("[WARNING] torch not available")
29
+
30
+ try:
31
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
32
+ TRANSFORMERS_AVAILABLE = True
33
+ except ImportError:
34
+ TRANSFORMERS_AVAILABLE = False
35
+ print("[WARNING] transformers not available")
36
 
37
  try:
38
  from datasets import load_dataset
39
  DATASETS_AVAILABLE = True
40
+ except ImportError:
41
  DATASETS_AVAILABLE = False
42
+ print("[WARNING] datasets not available")
43
 
44
  try:
45
  from sentence_transformers import SentenceTransformer
46
  SENTENCE_TRANSFORMERS_AVAILABLE = True
47
+ except ImportError:
48
  SENTENCE_TRANSFORMERS_AVAILABLE = False
49
+ print("[WARNING] sentence_transformers not available")
50
 
51
  try:
52
  import faiss
53
  FAISS_AVAILABLE = True
54
+ except ImportError:
55
  FAISS_AVAILABLE = False
56
+ print("[WARNING] faiss not available")
57
 
58
  try:
59
  from Bio import SeqIO
60
  BIOPYTHON_AVAILABLE = True
61
+ except ImportError:
62
  BIOPYTHON_AVAILABLE = False
63
+ print("[WARNING] biopython not available")
64
 
65
+ # μƒμˆ˜
66
  APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"
67
+ DISCLAIMER = "This tool is for research/education and is not a medical device. Do not use outputs for diagnosis or treatment decisions."
 
 
 
68
 
69
  # --------------- Helper Functions ---------------
70
 
71
  def get_secret(name: str, fallback: str = "") -> str:
72
+ """Get secret from st.secrets or environment"""
73
  try:
74
+ # Streamlit secrets
75
+ if hasattr(st, 'secrets') and name in st.secrets:
76
+ return st.secrets[name]
77
  except:
78
  pass
79
+ # Environment variable
80
  return os.environ.get(name, fallback)
81
 
82
  def brave_search(query: str, count: int = 5) -> List[Dict]:
83
+ """Brave Search API"""
84
  key = get_secret("BRAVE_API_KEY", "")
85
  if not key:
86
+ return [{
87
+ "title": "BRAVE_API_KEY missing",
88
+ "url": "",
89
+ "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar"
90
+ }]
91
 
92
  url = "https://api.search.brave.com/res/v1/web/search"
93
  headers = {
94
  "Accept": "application/json",
95
+ "X-Subscription-Token": key
 
96
  }
97
+ params = {"q": query, "count": count}
98
 
99
  try:
100
  r = requests.get(url, headers=headers, params=params, timeout=15)
 
105
  results.append({
106
  "title": item.get("title", ""),
107
  "url": item.get("url", ""),
108
+ "snippet": item.get("description", "")
109
  })
110
+ return results if results else [{"title": "No results", "url": "", "snippet": ""}]
111
  except Exception as e:
112
+ return [{"title": "Error", "url": "", "snippet": str(e)}]
113
 
114
+ def call_llm(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
115
+ """Call Fireworks AI API"""
116
  api_key = get_secret("FIREWORKS_API_KEY", "")
117
  if not api_key:
118
+ return "FIREWORKS_API_KEY missing. Set it in Secrets or sidebar."
119
 
120
  url = "https://api.fireworks.ai/inference/v1/chat/completions"
121
  payload = {
122
  "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
123
+ "messages": messages,
124
  "max_tokens": max_tokens,
125
+ "temperature": temperature,
126
  "top_p": 1,
 
 
127
  "frequency_penalty": 0,
128
+ "presence_penalty": 0
 
129
  }
130
  headers = {
 
131
  "Content-Type": "application/json",
132
  "Authorization": f"Bearer {api_key}"
133
  }
134
 
135
  try:
136
+ r = requests.post(url, headers=headers, json=payload, timeout=60)
137
  r.raise_for_status()
138
+ return r.json()["choices"][0]["message"]["content"]
 
139
  except Exception as e:
140
+ return f"[LLM Error] {e}"
141
 
142
+ def load_file_text(upload) -> str:
143
  """Load text from uploaded file"""
144
  name = upload.name.lower()
 
145
 
146
  try:
147
+ content = upload.read()
148
  text = content.decode("utf-8", errors="ignore")
149
  except:
150
+ return ""
151
 
152
+ # FASTA handling
153
  if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
 
154
  try:
155
+ upload.seek(0)
156
  records = list(SeqIO.parse(upload, "fasta"))
157
+ seqs = [f">{r.id}\n{str(r.seq)}" for r in records]
 
 
158
  return "\n\n".join(seqs)
159
  except:
160
  pass
161
 
162
  return text
163
 
164
+ def chunk_text(text: str, size: int = 1200, overlap: int = 200) -> List[str]:
165
+ """Split text into chunks"""
166
+ chunks = []
167
+ start = 0
168
+ text_len = len(text)
169
+
170
+ while start < text_len:
171
+ end = min(start + size, text_len)
172
+ chunks.append(text[start:end])
173
+ if end >= text_len:
174
+ break
175
+ start = end - overlap
176
+
177
+ return chunks
178
+
179
+ def build_index(texts: List[str]):
180
+ """Build vector index"""
181
  if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
182
+ return None, None
183
 
184
  try:
185
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
186
+ embeddings = model.encode(texts, show_progress_bar=False)
187
+
188
+ dim = embeddings.shape[1]
189
  index = faiss.IndexFlatIP(dim)
190
+ index.add(embeddings.astype("float32"))
191
+
192
+ return index, model
193
  except Exception as e:
194
+ st.warning(f"Index build failed: {e}")
195
+ return None, None
196
 
197
+ def search_index(query: str, index, model, texts: List[str], k: int = 4) -> List[Dict]:
198
  """Search vector index"""
199
  if index is None or model is None:
200
  return []
201
 
202
  try:
203
+ q_emb = model.encode([query])
204
+ D, I = index.search(q_emb.astype("float32"), k)
205
+
206
+ results = []
207
  for idx, score in zip(I[0], D[0]):
208
  if 0 <= idx < len(texts):
209
+ results.append({
210
+ "score": float(score),
211
+ "text": texts[idx]
212
+ })
213
+ return results
214
  except:
215
  return []
216
 
217
+ def esm2_embed(seq: str, model_name: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
218
+ """ESM-2 protein embedding"""
219
+ if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
220
+ return {"error": "PyTorch/Transformers not available"}
221
 
222
  try:
223
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
224
+ model = AutoModelForMaskedLM.from_pretrained(model_name)
 
 
 
225
  model.eval()
226
 
227
  with torch.no_grad():
228
+ inputs = tokenizer(seq, return_tensors="pt")
229
+ outputs = model(**inputs, output_hidden_states=True)
230
+ hidden = outputs.hidden_states[-1].mean(dim=1).squeeze(0)
231
+ vec = hidden.numpy()
232
+
233
+ return {
234
+ "embedding": vec.tolist(),
235
+ "size": vec.shape[0]
236
+ }
237
  except Exception as e:
238
  return {"error": str(e)}
239
 
240
+ def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
241
+ """DNA embedding"""
242
+ if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
243
+ return {"error": "PyTorch/Transformers not available"}
244
 
245
  try:
246
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
247
+ model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
 
 
 
248
  model.eval()
249
 
250
  with torch.no_grad():
251
+ inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=512)
252
+ outputs = model(**inputs)
253
+ hidden = outputs.last_hidden_state.mean(dim=1).squeeze(0)
254
+ vec = hidden.numpy()
255
+
256
+ return {
257
+ "embedding": vec.tolist(),
258
+ "size": vec.shape[0]
259
+ }
260
  except Exception as e:
261
  return {"error": str(e)}
262
 
263
+ def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
264
+ """Build context from sources"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  pieces = []
266
  sources = []
 
 
 
 
 
 
 
267
 
268
+ # File search
269
+ if index and model and docs:
270
+ hits = search_index(query, index, model, docs, k=4)
271
+ for h in hits:
272
+ pieces.append(f"[FILE] {h['text'][:500]}")
273
+ sources.append({"type": "file", "text": h['text'][:100]})
274
 
275
+ # Web search
276
  if use_web:
277
+ results = brave_search(query, count=web_k)
278
  for r in results:
279
+ pieces.append(f"[WEB] {r['title']}\n{r['snippet']}")
280
+ sources.append({"type": "web", "title": r['title'], "url": r['url']})
 
 
 
281
 
282
+ context = "\n\n---\n\n".join(pieces)[:4000]
283
  return context, sources
284
 
285
+ def answer_question(query: str, context: str) -> str:
286
+ """Generate answer"""
 
287
  system = (
288
+ "You are a bioinformatics assistant. Be concise and factual. "
289
+ "Never give medical advice. Answer in the user's language."
 
 
 
290
  )
291
+
292
+ user_msg = f"Context:\n{context}\n\nQuestion: {query}"
293
+
294
  messages = [
295
  {"role": "system", "content": system},
296
+ {"role": "user", "content": user_msg}
297
  ]
298
+
299
+ return call_llm(messages, temperature=0.4, max_tokens=1000)
300
 
301
  # --------------- Streamlit UI ---------------
302
 
 
304
  st.title(APP_TITLE)
305
  st.caption(DISCLAIMER)
306
 
307
+ # Session state init
308
+ if "docs" not in st.session_state:
 
 
 
 
309
  st.session_state.docs = []
310
+ if "index" not in st.session_state:
311
  st.session_state.index = None
312
+ if "model" not in st.session_state:
313
+ st.session_state.model = None
 
 
314
 
315
+ # Sidebar
316
  with st.sidebar:
317
+ st.header("Configuration")
318
+
319
+ fw_key = st.text_input(
320
+ "FIREWORKS_API_KEY",
321
+ value=get_secret("FIREWORKS_API_KEY", ""),
322
+ type="password"
323
+ )
324
+ brave_key = st.text_input(
325
+ "BRAVE_API_KEY",
326
+ value=get_secret("BRAVE_API_KEY", ""),
327
+ type="password"
328
+ )
329
 
330
  if fw_key:
331
  os.environ["FIREWORKS_API_KEY"] = fw_key
332
  if brave_key:
333
  os.environ["BRAVE_API_KEY"] = brave_key
334
 
335
+ st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
+ esm_model = st.text_input(
338
+ "ESM-2 Model",
339
+ value="facebook/esm2_t6_8M_UR50D"
340
+ )
341
+ dna_model = st.text_input(
342
+ "DNA Model",
343
+ value="zhihan1996/DNABERT-2-117M"
344
  )
345
 
346
+ use_web = st.checkbox("Enable web search", value=True)
347
+ web_results = st.slider("Web results", 1, 10, 3)
348
+
349
+ # Tabs
350
+ tab1, tab2, tab3, tab4 = st.tabs(["Chat", "Protein", "DNA", "About"])
351
+
352
+ # File upload
353
+ with st.expander("πŸ“ Upload Files", expanded=True):
354
+ files = st.file_uploader(
355
+ "Upload text/FASTA files",
356
+ type=["txt", "fa", "fasta", "csv", "json"],
357
+ accept_multiple_files=True
 
358
  )
359
 
360
+ if files:
361
  docs = []
362
+ for f in files:
363
  try:
364
+ text = load_file_text(f)
365
+ if text:
366
+ docs.extend(chunk_text(text))
367
  except Exception as e:
368
+ st.error(f"Error reading {f.name}: {e}")
 
 
 
369
 
370
+ if docs:
371
+ st.session_state.docs = docs
372
+ st.success(f"Loaded {len(docs)} chunks")
373
+
374
+ if SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
375
+ with st.spinner("Building index..."):
376
+ index, model = build_index(docs)
377
+ if index:
378
+ st.session_state.index = index
379
+ st.session_state.model = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
  # Chat tab
382
+ with tab1:
383
+ st.subheader("πŸ’¬ Chat Assistant")
384
+
385
+ question = st.text_area(
386
+ "Ask about proteins, DNA, or bioinformatics:",
387
+ value="What is the role of ESM-2 embeddings in protein analysis?",
388
+ height=100
389
+ )
390
+
391
+ if st.button("Get Answer", type="primary"):
392
+ if not get_secret("FIREWORKS_API_KEY"):
393
+ st.error("Please set FIREWORKS_API_KEY")
394
+ else:
395
+ with st.spinner("Thinking..."):
396
+ context, sources = build_context(
397
+ question,
398
+ st.session_state.docs,
399
+ st.session_state.index,
400
+ st.session_state.model,
401
+ use_web,
402
+ web_results
403
+ )
404
+
405
+ answer = answer_question(question, context)
406
+
407
+ st.markdown("### Answer")
408
+ st.write(answer)
409
+
410
+ if sources:
411
+ st.markdown("### Sources")
412
+ for s in sources:
413
+ if s["type"] == "web":
414
+ st.write(f"- 🌐 [{s['title']}]({s['url']})")
415
+ elif s["type"] == "file":
416
+ st.write(f"- πŸ“„ File: {s['text'][:80]}...")
417
 
418
  # Protein tab
419
+ with tab2:
420
+ st.subheader("🧬 Protein Analysis")
421
+
422
+ protein_seq = st.text_area(
423
+ "Enter protein sequence:",
424
+ value="MKTIIALSYIFCLVFA",
425
+ height=100
426
+ )
427
 
428
  col1, col2 = st.columns(2)
429
+
430
  with col1:
431
+ if st.button("Analyze Protein"):
432
+ seq = protein_seq.strip().upper()
433
+
434
+ # Basic stats
435
+ st.write(f"**Length:** {len(seq)}")
436
+ st.write(f"**Unique AAs:** {len(set(seq))}")
437
+
438
+ # ESM-2 embedding
439
+ if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
440
+ with st.spinner("Computing embedding..."):
441
+ result = esm2_embed(seq, esm_model)
442
+ if "error" in result:
443
+ st.error(result["error"])
444
+ else:
445
+ st.success(f"Embedding size: {result['size']}")
446
+ st.json({"preview": result["embedding"][:5]})
447
  else:
448
+ st.warning("PyTorch not available for embeddings")
 
449
 
450
  with col2:
451
+ st.info("Amino acid composition and structure prediction features coming soon")
 
 
 
 
 
452
 
453
  # DNA tab
454
+ with tab3:
455
+ st.subheader("🧬 DNA Analysis")
456
+
457
+ dna_seq = st.text_area(
458
+ "Enter DNA sequence:",
459
+ value="ATGCGATCGTAGC",
460
+ height=100
461
+ )
462
+
463
+ col1, col2 = st.columns(2)
464
+
465
+ with col1:
466
+ if st.button("Analyze DNA"):
467
+ seq = dna_seq.strip().upper()
468
+
469
+ # GC content
470
+ gc = (seq.count("G") + seq.count("C")) / len(seq) if seq else 0
471
+
472
+ st.write(f"**Length:** {len(seq)}")
473
+ st.write(f"**GC Content:** {gc:.2%}")
474
+
475
+ # DNA embedding
476
+ if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
477
+ with st.spinner("Computing embedding..."):
478
+ result = dna_embed(seq, dna_model)
479
+ if "error" in result:
480
+ st.error(result["error"])
481
+ else:
482
+ st.success(f"Embedding size: {result['size']}")
483
+ st.json({"preview": result["embedding"][:5]})
484
  else:
485
+ st.warning("PyTorch not available for embeddings")
486
+
487
+ with col2:
488
+ st.info("Motif analysis and structure prediction coming soon")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
  # About tab
491
+ with tab4:
492
+ st.subheader("ℹ️ About")
493
+ st.markdown("""
494
+ ### Features
495
+ - πŸ’¬ RAG-based chat for bioinformatics questions
496
+ - 🧬 Protein sequence analysis with ESM-2
497
+ - 🧬 DNA sequence analysis with DNABERT-2
498
+ - πŸ” Web search integration via Brave API
499
+ - πŸ“ File upload and vector search
500
+
501
+ ### Models
502
+ - **Proteins:** ESM-2 (Facebook)
503
+ - **DNA:** DNABERT-2 (Microsoft)
504
+ - **LLM:** Llama 3.1 70B (via Fireworks)
505
+
506
+ ### Disclaimer
507
+ This tool is for research and educational purposes only.
508
+ Not for medical diagnosis or treatment decisions.
509
+ """)
510
+
511
+ # Dependency check
512
+ st.divider()
513
+ st.subheader("System Status")
514
+ deps = {
515
+ "PyTorch": TORCH_AVAILABLE,
516
+ "Transformers": TRANSFORMERS_AVAILABLE,
517
+ "Sentence Transformers": SENTENCE_TRANSFORMERS_AVAILABLE,
518
+ "FAISS": FAISS_AVAILABLE,
519
+ "BioPython": BIOPYTHON_AVAILABLE,
520
+ "Datasets": DATASETS_AVAILABLE
521
+ }
522
+
523
+ for name, available in deps.items():
524
+ if available:
525
+ st.success(f"βœ… {name}")
526
+ else:
527
+ st.warning(f"⚠️ {name} not available")