openfree commited on
Commit
29ce347
·
verified ·
1 Parent(s): cb1dc3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -209
app.py CHANGED
@@ -95,7 +95,7 @@ def call_fireworks(messages: List[Dict], temperature: float = 0.6, max_tokens: i
95
 
96
  url = "https://api.fireworks.ai/inference/v1/chat/completions"
97
  payload = {
98
- "model": "accounts/fireworks/models/qwen3-235b-a22b-instruct-2507",
99
  "max_tokens": max_tokens,
100
  "top_p": 1,
101
  "top_k": 40,
@@ -282,226 +282,221 @@ def chat_answer(user_query: str, index, index_model, docs: List[str], loaded_dat
282
  answer = call_fireworks(messages, temperature=0.4, max_tokens=1200)
283
  return answer, sources
284
 
285
- # --------------- Main Application ---------------
286
 
287
- def main():
288
- st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
289
- st.title(APP_TITLE)
290
- st.caption(DISCLAIMER)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
- # Check dependencies status
293
- if not TORCH_AVAILABLE:
294
- st.warning("⏳ PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")
 
295
 
296
- # Initialize session state
297
- if 'docs' not in st.session_state:
298
- st.session_state.docs = []
299
- if 'index' not in st.session_state:
300
- st.session_state.index = None
301
- if 'index_model' not in st.session_state:
302
- st.session_state.index_model = None
303
- if 'loaded_datasets' not in st.session_state:
304
- st.session_state.loaded_datasets = []
 
 
305
 
306
- # Sidebar configuration
307
- with st.sidebar:
308
- st.header("Keys and settings")
309
- fw_key = st.text_input("FIREWORKS_API_KEY", value=get_secret("FIREWORKS_API_KEY", ""), type="password")
310
- brave_key = st.text_input("BRAVE_API_KEY", value=get_secret("BRAVE_API_KEY", ""), type="password")
311
-
312
- if fw_key:
313
- os.environ["FIREWORKS_API_KEY"] = fw_key
314
- if brave_key:
315
- os.environ["BRAVE_API_KEY"] = brave_key
316
-
317
- st.markdown("### Model selections")
318
- esm2_id = st.text_input(
319
- "Protein model (ESM-2)",
320
- value="facebook/esm2_t6_8M_UR50D",
321
- help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
322
- )
323
- dna_id = st.text_input(
324
- "DNA model",
325
- value="zhihan1996/DNABERT-2-117M",
326
- help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
327
- )
328
-
329
- use_web = st.checkbox("Use Brave web search for context", value=True)
330
- web_k = st.slider("Web results", 1, 10, 4)
331
-
332
- st.markdown("### Datasets (optional)")
333
- dataset_ids = st.text_area(
334
- "Datasets to load (one per line)",
335
- value="",
336
- help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'"
337
- )
338
-
339
- st.divider()
340
- st.markdown("Files you upload are indexed locally and used for answers.")
341
 
342
- # Main tabs
343
- tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])
 
 
 
 
344
 
345
- # File upload section
346
- with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
347
- uploads = st.file_uploader(
348
- "Add files",
349
- type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
350
- accept_multiple_files=True,
351
- key="file_uploader"
352
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
- if uploads:
355
- docs = []
356
- for up in uploads:
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  try:
358
- txt = load_text_from_file(up)
359
- docs.extend(chunk_text(txt))
 
 
 
 
 
 
 
 
 
360
  except Exception as e:
361
- st.warning(f"Failed to read {up.name}: {e}")
362
-
363
- st.session_state.docs = docs
364
- st.caption(f"Indexed chunks: {len(docs)}")
365
-
366
- # Build index if docs available
367
- if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
368
- with st.spinner("Building vector index..."):
369
- index, emb, index_model = build_vector_index(docs)
370
- st.session_state.index = index
371
- st.session_state.index_model = index_model
372
- else:
373
- st.caption("No files uploaded yet")
374
-
375
- # Load datasets if specified
376
- if dataset_ids.strip() and DATASETS_AVAILABLE:
377
- dataset_list = [x.strip() for x in dataset_ids.splitlines() if x.strip()]
378
- if dataset_list != [d[0] for d in st.session_state.loaded_datasets]:
379
- st.session_state.loaded_datasets = []
380
- for rid in dataset_list:
381
- with st.spinner(f"Loading dataset {rid}..."):
382
- try:
383
- ds = load_dataset(rid)
384
- sample = ""
385
- for split in ds.keys():
386
- try:
387
- row = ds[split][0]
388
- sample = json.dumps(row, ensure_ascii=False)[:500]
389
- break
390
- except:
391
- pass
392
- st.session_state.loaded_datasets.append((rid, sample))
393
- st.success(f"Loaded {rid}")
394
- except Exception as e:
395
- st.error(f"Failed to load {rid}: {e}")
396
-
397
- # Chat tab
398
- with tabs[0]:
399
- st.subheader("Chat")
400
- q = st.text_area("Ask a question about protein/DNA", value="ESM-2 임베딩은 단백질 기능 해석에 어떻게 도움되나요?")
401
-
402
- if st.button("Answer", type="primary"):
403
- with st.spinner("Thinking..."):
404
- ans, srcs = chat_answer(
405
- q,
406
- st.session_state.index,
407
- st.session_state.index_model,
408
- st.session_state.docs,
409
- st.session_state.loaded_datasets,
410
- use_web,
411
- web_k
412
- )
413
- st.write(ans)
414
-
415
- if srcs:
416
- st.markdown("#### Sources")
417
- for s in srcs:
418
- if s.get("type") == "web" and s.get("url"):
419
- st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
420
- elif s.get("type") == "dataset":
421
- st.markdown(f"- dataset: {s.get('id')}")
422
- elif s.get("type") == "file":
423
- snippet = s.get("text", "")
424
- st.markdown(f"- file snippet: {snippet[:120]}...")
425
 
426
- # Protein tab
427
- with tabs[1]:
428
- st.subheader("Protein analysis")
429
- seq = st.text_area("Protein sequence (amino acids only)", value="MKTIIALSYIFCLVFADYKDDDDK")
 
 
 
 
 
 
 
 
430
 
431
- col1, col2 = st.columns(2)
432
- with col1:
433
- st.caption("ESM-2 embedding")
434
- if st.button("Run ESM-2", key="run_esm2"):
435
- with st.spinner("Computing ESM-2 embedding..."):
436
- out = esm2_embed(seq.strip(), esm2_id)
437
- if "error" in out:
438
- st.error(out["error"])
439
- else:
440
- st.success(f"Vector size: {out['hidden_size']}")
441
- st.json({"embedding_preview": out["embedding"][:8]})
442
-
443
- with col2:
444
- st.caption("Quick stats")
445
- s = seq.replace("\n", "").replace(" ", "").upper()
446
- length = len(s)
447
- aa_set = sorted(set(list(s)))
448
- st.write(f"Length: {length}")
449
- st.write(f"Unique AAs: {''.join(aa_set)[:30]}")
450
 
451
- # DNA tab
452
- with tabs[2]:
453
- st.subheader("DNA analysis")
454
- dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
455
-
456
- col3, col4 = st.columns(2)
457
- with col3:
458
- st.caption("DNA embedding")
459
- if st.button("Run DNA embed", key="run_dna"):
460
- with st.spinner("Computing DNA embedding..."):
461
- out = dna_embed(dseq.strip(), dna_id)
462
- if "error" in out:
463
- st.error(out["error"])
464
- else:
465
- st.success(f"Vector size: {out['hidden_size']}")
466
- st.json({"embedding_preview": out["embedding"][:8]})
467
-
468
- with col4:
469
- st.caption("GC content")
470
- s = dseq.upper().replace("N", "").replace(" ", "").replace("\n", "")
471
- if len(s) > 0:
472
- gc = (s.count("G") + s.count("C")) / len(s)
473
  else:
474
- gc = 0
475
- st.write(f"Length: {len(s)}")
476
- st.write(f"GC: {gc:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
477
 
478
- # Examples tab
479
- with tabs[3]:
480
- st.subheader("Examples")
481
- st.markdown("### Example questions you can ask:")
482
- st.markdown("- 업로드한 FASTA에서 특정 단백질의 기능 요약과 변이 영향 질문")
483
- st.markdown("- DNA 서열에서 프로모터 가능성과 전사인자 모티프 관련 근거 요청")
484
- st.markdown("- Enzyme active site 근접 변이의 리스크 해석 (연구 관점)")
485
- st.markdown("- ENCODE/UniProt/AlphaFold 개념 설명 요청")
486
- st.markdown("- RAG 기반으로 문서 인용과 함께 간략 답변 요청")
 
 
487
 
488
- # About tab
489
- with tabs[4]:
490
- st.subheader("About this Space")
491
- st.write("**Models suggested:**")
492
- st.write("- ESM-2 for proteins")
493
- st.write("- DNABERT-2 or Nucleotide Transformer for DNA")
494
- st.write("")
495
- st.write("**Common datasets:**")
496
- st.write("- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar")
497
- st.write("")
498
- st.write("**Features:**")
499
- st.write("- Web search powered by Brave Search API")
500
- st.write("- LLM powered by Fireworks AI")
501
- st.write("- Vector search with FAISS")
502
- st.write("")
503
- st.info(DISCLAIMER)
504
-
505
- # Run the app
506
- if __name__ == "__main__":
507
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  url = "https://api.fireworks.ai/inference/v1/chat/completions"
97
  payload = {
98
+ "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
99
  "max_tokens": max_tokens,
100
  "top_p": 1,
101
  "top_k": 40,
 
282
  answer = call_fireworks(messages, temperature=0.4, max_tokens=1200)
283
  return answer, sources
284
 
285
+ # --------------- Streamlit UI ---------------
286
 
287
+ st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
288
+ st.title(APP_TITLE)
289
+ st.caption(DISCLAIMER)
290
+
291
+ # Check dependencies status
292
+ if not TORCH_AVAILABLE:
293
+ st.warning("⏳ PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")
294
+
295
+ # Initialize session state
296
+ if 'docs' not in st.session_state:
297
+ st.session_state.docs = []
298
+ if 'index' not in st.session_state:
299
+ st.session_state.index = None
300
+ if 'index_model' not in st.session_state:
301
+ st.session_state.index_model = None
302
+ if 'loaded_datasets' not in st.session_state:
303
+ st.session_state.loaded_datasets = []
304
+
305
+ # Sidebar configuration
306
+ with st.sidebar:
307
+ st.header("Keys and settings")
308
+ fw_key = st.text_input("FIREWORKS_API_KEY", value=get_secret("FIREWORKS_API_KEY", ""), type="password")
309
+ brave_key = st.text_input("BRAVE_API_KEY", value=get_secret("BRAVE_API_KEY", ""), type="password")
310
 
311
+ if fw_key:
312
+ os.environ["FIREWORKS_API_KEY"] = fw_key
313
+ if brave_key:
314
+ os.environ["BRAVE_API_KEY"] = brave_key
315
 
316
+ st.markdown("### Model selections")
317
+ esm2_id = st.text_input(
318
+ "Protein model (ESM-2)",
319
+ value="facebook/esm2_t6_8M_UR50D",
320
+ help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow."
321
+ )
322
+ dna_id = st.text_input(
323
+ "DNA model",
324
+ value="zhihan1996/DNABERT-2-117M",
325
+ help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref"
326
+ )
327
 
328
+ use_web = st.checkbox("Use Brave web search for context", value=True)
329
+ web_k = st.slider("Web results", 1, 10, 4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
+ st.markdown("### Datasets (optional)")
332
+ dataset_ids = st.text_area(
333
+ "Datasets to load (one per line)",
334
+ value="",
335
+ help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'"
336
+ )
337
 
338
+ st.divider()
339
+ st.markdown("Files you upload are indexed locally and used for answers.")
340
+
341
+ # Main tabs
342
+ tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])
343
+
344
+ # File upload section
345
+ with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
346
+ uploads = st.file_uploader(
347
+ "Add files",
348
+ type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
349
+ accept_multiple_files=True,
350
+ key="file_uploader"
351
+ )
352
+
353
+ if uploads:
354
+ docs = []
355
+ for up in uploads:
356
+ try:
357
+ txt = load_text_from_file(up)
358
+ docs.extend(chunk_text(txt))
359
+ except Exception as e:
360
+ st.warning(f"Failed to read {up.name}: {e}")
361
+
362
+ st.session_state.docs = docs
363
+ st.caption(f"Indexed chunks: {len(docs)}")
364
 
365
+ # Build index if docs available
366
+ if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
367
+ with st.spinner("Building vector index..."):
368
+ index, emb, index_model = build_vector_index(docs)
369
+ st.session_state.index = index
370
+ st.session_state.index_model = index_model
371
+ else:
372
+ st.caption("No files uploaded yet")
373
+
374
+ # Load datasets if specified
375
+ if dataset_ids.strip() and DATASETS_AVAILABLE:
376
+ dataset_list = [x.strip() for x in dataset_ids.splitlines() if x.strip()]
377
+ if dataset_list != [d[0] for d in st.session_state.loaded_datasets]:
378
+ st.session_state.loaded_datasets = []
379
+ for rid in dataset_list:
380
+ with st.spinner(f"Loading dataset {rid}..."):
381
  try:
382
+ ds = load_dataset(rid)
383
+ sample = ""
384
+ for split in ds.keys():
385
+ try:
386
+ row = ds[split][0]
387
+ sample = json.dumps(row, ensure_ascii=False)[:500]
388
+ break
389
+ except:
390
+ pass
391
+ st.session_state.loaded_datasets.append((rid, sample))
392
+ st.success(f"Loaded {rid}")
393
  except Exception as e:
394
+ st.error(f"Failed to load {rid}: {e}")
395
+
396
+ # Chat tab
397
+ with tabs[0]:
398
+ st.subheader("Chat")
399
+ q = st.text_area("Ask a question about protein/DNA", value="ESM-2 임베딩은 단백질 기능 해석에 어떻게 도움되나요?")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
+ if st.button("Answer", type="primary"):
402
+ with st.spinner("Thinking..."):
403
+ ans, srcs = chat_answer(
404
+ q,
405
+ st.session_state.index,
406
+ st.session_state.index_model,
407
+ st.session_state.docs,
408
+ st.session_state.loaded_datasets,
409
+ use_web,
410
+ web_k
411
+ )
412
+ st.write(ans)
413
 
414
+ if srcs:
415
+ st.markdown("#### Sources")
416
+ for s in srcs:
417
+ if s.get("type") == "web" and s.get("url"):
418
+ st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
419
+ elif s.get("type") == "dataset":
420
+ st.markdown(f"- dataset: {s.get('id')}")
421
+ elif s.get("type") == "file":
422
+ snippet = s.get("text", "")
423
+ st.markdown(f"- file snippet: {snippet[:120]}...")
424
+
425
+ # Protein tab
426
+ with tabs[1]:
427
+ st.subheader("Protein analysis")
428
+ seq = st.text_area("Protein sequence (amino acids only)", value="MKTIIALSYIFCLVFADYKDDDDK")
 
 
 
 
429
 
430
+ col1, col2 = st.columns(2)
431
+ with col1:
432
+ st.caption("ESM-2 embedding")
433
+ if st.button("Run ESM-2", key="run_esm2"):
434
+ with st.spinner("Computing ESM-2 embedding..."):
435
+ out = esm2_embed(seq.strip(), esm2_id)
436
+ if "error" in out:
437
+ st.error(out["error"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  else:
439
+ st.success(f"Vector size: {out['hidden_size']}")
440
+ st.json({"embedding_preview": out["embedding"][:8]})
441
+
442
+ with col2:
443
+ st.caption("Quick stats")
444
+ s = seq.replace("\n", "").replace(" ", "").upper()
445
+ length = len(s)
446
+ aa_set = sorted(set(list(s)))
447
+ st.write(f"Length: {length}")
448
+ st.write(f"Unique AAs: {''.join(aa_set)[:30]}")
449
+
450
+ # DNA tab
451
+ with tabs[2]:
452
+ st.subheader("DNA analysis")
453
+ dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")
454
 
455
+ col3, col4 = st.columns(2)
456
+ with col3:
457
+ st.caption("DNA embedding")
458
+ if st.button("Run DNA embed", key="run_dna"):
459
+ with st.spinner("Computing DNA embedding..."):
460
+ out = dna_embed(dseq.strip(), dna_id)
461
+ if "error" in out:
462
+ st.error(out["error"])
463
+ else:
464
+ st.success(f"Vector size: {out['hidden_size']}")
465
+ st.json({"embedding_preview": out["embedding"][:8]}")
466
 
467
+ with col4:
468
+ st.caption("GC content")
469
+ s = dseq.upper().replace("N", "").replace(" ", "").replace("\n", "")
470
+ if len(s) > 0:
471
+ gc = (s.count("G") + s.count("C")) / len(s)
472
+ else:
473
+ gc = 0
474
+ st.write(f"Length: {len(s)}")
475
+ st.write(f"GC: {gc:.3f}")
476
+
477
+ # Examples tab
478
+ with tabs[3]:
479
+ st.subheader("Examples")
480
+ st.markdown("### Example questions you can ask:")
481
+ st.markdown("- 업로드한 FASTA에서 특정 단백질의 기능 요약과 변이 영향 질문")
482
+ st.markdown("- DNA 서열에서 프로모터 가능성과 전사인자 모티프 관련 근거 요청")
483
+ st.markdown("- Enzyme active site 근접 변이의 리스크 해석 (연구 관점)")
484
+ st.markdown("- ENCODE/UniProt/AlphaFold 개념 설명 요청")
485
+ st.markdown("- RAG 기반으로 문서 인용과 함께 간략 답변 요청")
486
+
487
+ # About tab
488
+ with tabs[4]:
489
+ st.subheader("About this Space")
490
+ st.write("**Models suggested:**")
491
+ st.write("- ESM-2 for proteins")
492
+ st.write("- DNABERT-2 or Nucleotide Transformer for DNA")
493
+ st.write("")
494
+ st.write("**Common datasets:**")
495
+ st.write("- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar")
496
+ st.write("")
497
+ st.write("**Features:**")
498
+ st.write("- Web search powered by Brave Search API")
499
+ st.write("- LLM powered by Fireworks AI")
500
+ st.write("- Vector search with FAISS")
501
+ st.write("")
502
+ st.info(DISCLAIMER)