|
|
import os |
|
|
import json |
|
|
import time |
|
|
from typing import List, Dict, Tuple |
|
|
|
|
|
import streamlit as st |
|
|
import requests |
|
|
|
|
|
|
|
|
# Optional heavy dependencies. Each probe records a *_AVAILABLE flag instead of failing,
# so the app still starts (with reduced features) when a package is missing or broken.
# `except Exception` rather than `ImportError` is deliberate: a package whose native
# libraries fail to load can raise other exception types at import time.
try:
    import torch
    from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
except Exception:
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True

try:
    from datasets import load_dataset
except Exception:
    DATASETS_AVAILABLE = False
else:
    DATASETS_AVAILABLE = True

try:
    from sentence_transformers import SentenceTransformer
except Exception:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
else:
    SENTENCE_TRANSFORMERS_AVAILABLE = True

try:
    import faiss
except Exception:
    FAISS_AVAILABLE = False
else:
    FAISS_AVAILABLE = True

try:
    from Bio import SeqIO
except Exception:
    BIOPYTHON_AVAILABLE = False
else:
    BIOPYTHON_AVAILABLE = True
|
|
|
|
|
|
|
|
# Page title shown in the browser tab and page header.
APP_TITLE: str = "BioSeq Chat: Protein & DNA Assistant"

# Safety notice shown under the page title and again in the About tab.
DISCLAIMER: str = " ".join(
    [
        "This tool is for research/education and is not a medical device.",
        "Do not use outputs for diagnosis or treatment decisions.",
    ]
)
|
|
|
|
|
|
|
|
|
|
|
def get_secret(name: str, fallback: str = "") -> str:
    """Return the configuration value *name* as a string.

    Lookup order: ``st.secrets[name]``, then ``os.environ[name]``, then *fallback*.
    Any error raised while reading ``st.secrets`` (for example when no secrets file is
    configured) is treated as "no secret", and the environment/fallback value is used.

    NOTE(review): because ``st.secrets`` is consulted first, a key typed into the sidebar
    (which the app writes to ``os.environ``) cannot override a key that is also defined
    as a Streamlit secret.
    """
    env_or_fallback = os.environ.get(name, fallback)
    try:
        value = st.secrets.get(name, env_or_fallback)
    except Exception:
        # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        return env_or_fallback
    # TOML secrets may be numbers or tables; callers put this into HTTP headers, so return str.
    return value if isinstance(value, str) else str(value)
|
|
|
|
|
def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Run a Brave web search and return up to *count* hits.

    Each hit is a dict with ``title``, ``url`` and ``snippet`` keys. This never raises:
    a missing API key, a request/parse error, or an empty result set is reported as a
    single placeholder entry whose ``url`` is empty and whose text describes the case.
    """
    api_key = get_secret("BRAVE_API_KEY", "")
    if not api_key:
        return [{
            "title": "BRAVE_API_KEY is missing",
            "url": "",
            "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar to enable web search.",
        }]

    try:
        response = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            headers={
                "Accept": "application/json",
                "X-Subscription-Token": api_key,
                "Accept-Encoding": "gzip",
            },
            params={"q": query, "count": count, "country": "us"},
            timeout=15,
        )
        response.raise_for_status()
        web_hits = response.json().get("web", {}).get("results", [])
        results = [
            {
                "title": hit.get("title", ""),
                "url": hit.get("url", ""),
                "snippet": hit.get("description", ""),
            }
            for hit in web_hits[:count]
        ]
    except Exception as exc:
        return [{"title": "Search error", "url": "", "snippet": str(exc)}]

    return results or [{"title": "No results", "url": "", "snippet": "Query returned no results."}]
|
|
|
|
|
def call_fireworks(
    messages: List[Dict],
    temperature: float = 0.6,
    max_tokens: int = 1024,
    model: str = "accounts/fireworks/models/llama-v3p1-70b-instruct",
) -> str:
    """Send *messages* to the Fireworks AI chat-completions endpoint and return the reply.

    Never raises. A missing API key is reported as a plain message; any request or parse
    failure is returned as a string prefixed with ``[Fireworks API error]``. This lets the UI
    display the result directly.

    Args:
        messages: chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: sampling temperature.
        max_tokens: upper bound on generated tokens.
        model: Fireworks model path; a parameter so a retired model can be swapped
            without editing this function.
    """
    api_key = get_secret("FIREWORKS_API_KEY", "")
    if not api_key:
        return "FIREWORKS_API_KEY is missing. Set it in Secrets or the sidebar."

    url = "https://api.fireworks.ai/inference/v1/chat/completions"
    payload = {
        "model": model,
        "max_tokens": max_tokens,
        "top_p": 1,
        "top_k": 40,
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "temperature": temperature,
        "messages": messages,
    }
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    try:
        r = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
        r.raise_for_status()
        data = r.json()
        # `content` may be null in some responses; keep the declared str return type.
        return data["choices"][0]["message"]["content"] or ""
    except requests.HTTPError as e:
        # raise_for_status() only reports the status line. The response body carries the
        # API's explanation (e.g. unknown model, invalid key), so include a bounded excerpt.
        body = e.response.text[:500] if e.response is not None else ""
        return f"[Fireworks API error] {e} {body}".rstrip()
    except Exception as e:
        return f"[Fireworks API error] {e}"
|
|
|
|
|
def load_text_from_file(upload) -> str:
    """Return the text content of an uploaded file.

    The bytes are decoded as UTF-8, and undecodable bytes are dropped. For FASTA uploads
    (``.fa``, ``.fasta``, ``.faa``, ``.fna``), when Biopython is available, the result is
    normalised to one ``>id`` line plus a single-line sequence per record. If parsing fails
    or finds no records, the raw text is returned instead.

    Args:
        upload: a file-like object with a ``name`` attribute (e.g. Streamlit's
            ``UploadedFile``, a ``BytesIO`` subclass).
    """
    import io

    name = upload.name.lower()
    # getvalue() returns the whole buffer regardless of the current read position,
    # whereas read() returns only what remains after any earlier read.
    content = upload.getvalue() if hasattr(upload, "getvalue") else upload.read()
    if isinstance(content, (bytes, bytearray)):
        text = bytes(content).decode("utf-8", errors="ignore")
    else:
        text = str(content)

    if name.endswith((".fa", ".fasta", ".faa", ".fna")) and BIOPYTHON_AVAILABLE:
        try:
            # SeqIO.parse requires a *text* handle; the upload itself is a binary buffer,
            # so passing it directly always failed and the FASTA path was never taken.
            records = list(SeqIO.parse(io.StringIO(text), "fasta"))
        except Exception:
            records = []  # malformed FASTA: best effort, fall back to the raw text
        if records:
            return "\n\n".join(f">{rec.id}\n{rec.seq}" for rec in records)

    return text
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def _load_sentence_embedder(embedder_name: str):
    """Load a SentenceTransformer once per process so Streamlit reruns reuse it."""
    return SentenceTransformer(embedder_name)


def build_vector_index(texts: List[str], embedder_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Embed *texts* and store them in an in-memory FAISS inner-product index.

    Embeddings are requested with ``normalize_embeddings=True``, so inner-product scores
    behave like cosine similarity.

    Returns:
        ``(index, embeddings, model)``. Returns ``(None, None, None)`` when
        sentence-transformers or FAISS is unavailable, when *texts* is empty, or when
        building fails (a Streamlit warning is shown in that last case).
    """
    if not SENTENCE_TRANSFORMERS_AVAILABLE or not FAISS_AVAILABLE:
        return None, None, None
    if not texts:
        # Nothing to index; avoid a spurious "Failed to build index" warning.
        return None, None, None

    try:
        model = _load_sentence_embedder(embedder_name)
        emb = model.encode(texts, show_progress_bar=False, normalize_embeddings=True)
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb.astype("float32"))
        return index, emb, model
    except Exception as e:
        st.warning(f"Failed to build index: {e}")
        return None, None, None
|
|
|
|
|
def search_index(query: str, index, model, texts: List[str], k: int = 4):
    """Return up to *k* chunks of *texts* most similar to *query*.

    Args:
        query: free-text query, embedded with *model* (normalised, like the index).
        index: FAISS-style index whose row ids are positions in *texts*.
        model: embedder exposing ``encode(list_of_str, normalize_embeddings=True)``.
        texts: the chunks the index was built from.
        k: number of neighbours to request.

    Returns:
        A list of ``{"score": float, "text": str}`` in the order the index returns them.
        Ids outside *texts* (FAISS pads missing neighbours with -1) are skipped. An empty
        list is returned when there is no index, no model or no texts, or when the search
        fails (a warning is shown in that case).
    """
    if index is None or model is None or not texts:
        return []

    try:
        q = model.encode([query], normalize_embeddings=True)
        scores, ids = index.search(q.astype("float32"), k)
    except Exception as e:
        st.warning(f"Vector search failed: {e}")
        return []

    return [
        {"score": float(score), "text": texts[idx]}
        for idx, score in zip(ids[0], scores[0])
        if 0 <= idx < len(texts)
    ]
|
|
|
|
|
@st.cache_resource(show_spinner=False, max_entries=2)
def _load_esm2(model_id: str, trust_remote_code: bool = False):
    """Load and cache an ESM-2 tokenizer/model pair (reused across reruns and sessions)."""
    from transformers import AutoTokenizer, AutoModelForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    model = AutoModelForMaskedLM.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    model.eval()
    return tokenizer, model


def esm2_embed(seq: str, model_id: str = "facebook/esm2_t6_8M_UR50D", trust_remote_code: bool = False) -> Dict:
    """Compute a single protein embedding with an ESM-2 masked-LM checkpoint.

    Whitespace is removed and residues are upper-cased before tokenisation. The embedding
    is the mean of the final hidden-state layer over all tokens the tokenizer produced,
    including any special tokens it adds.

    Args:
        seq: amino-acid sequence.
        model_id: Hugging Face model id. It comes from a free-text sidebar field.
        trust_remote_code: ESM models are native to transformers, so this defaults to
            False. Enabling it for an arbitrary user-supplied id would execute that
            repository's Python code on the server.

    Returns:
        ``{"embedding": list[float], "hidden_size": int}`` on success, else ``{"error": str}``.
    """
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}

    residues = "".join(seq.split()).upper()
    if not residues:
        return {"error": "Empty protein sequence."}

    try:
        import torch

        # Loading weights is by far the slowest step; the cached loader avoids
        # re-reading the checkpoint on every button click.
        tokenizer, model = _load_esm2(model_id, trust_remote_code)
        with torch.no_grad():
            toks = tokenizer(residues, return_tensors="pt")
            out = model(**toks, output_hidden_states=True)
            hidden = out.hidden_states[-1].mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
# Repositories whose custom modelling code we knowingly execute. The model id comes from
# a free-text sidebar field, so trusting remote code for *any* id would let a visitor run
# arbitrary Python on the server.
TRUSTED_REMOTE_CODE_DNA_MODELS = frozenset({"zhihan1996/DNABERT-2-117M"})


@st.cache_resource(show_spinner=False, max_entries=2)
def _load_dna_model(model_id: str, trust_remote_code: bool):
    """Load and cache a DNA tokenizer/model pair (reused across reruns and sessions)."""
    from transformers import AutoTokenizer, AutoModel

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    model = AutoModel.from_pretrained(model_id, trust_remote_code=trust_remote_code)
    model.eval()
    return tokenizer, model


def dna_embed(seq: str, model_id: str = "zhihan1996/DNABERT-2-117M", trust_remote_code=None) -> Dict:
    """Compute a single DNA embedding with a Hugging Face encoder.

    Whitespace is removed and bases are upper-cased. Input is truncated to 4096 tokens.
    The embedding is the mean over tokens of the model's first output, which is the last
    hidden state.

    Args:
        seq: nucleotide sequence.
        model_id: Hugging Face model id (from a free-text sidebar field).
        trust_remote_code: ``None`` (default) trusts remote code only for ids in
            ``TRUSTED_REMOTE_CODE_DNA_MODELS``; pass a bool to override explicitly.

    Returns:
        ``{"embedding": list[float], "hidden_size": int}`` on success, else ``{"error": str}``.
    """
    if not TORCH_AVAILABLE:
        return {"error": "Transformers/torch not available. Please wait for dependencies to install."}

    bases = "".join(seq.split()).upper()
    if not bases:
        return {"error": "Empty DNA sequence."}

    if trust_remote_code is None:
        trust_remote_code = model_id in TRUSTED_REMOTE_CODE_DNA_MODELS

    try:
        import torch

        tokenizer, model = _load_dna_model(model_id, bool(trust_remote_code))
        with torch.no_grad():
            toks = tokenizer(bases, return_tensors="pt", truncation=True, max_length=4096)
            out = model(**toks)
            # Index 0 is the last hidden state both for transformers ModelOutput objects and
            # for the plain tuple DNABERT-2's custom code returns, which has no
            # `.last_hidden_state` attribute.
            hidden = out[0].mean(dim=1).squeeze(0)
        vec = hidden.detach().cpu().numpy()
        return {"embedding": vec.tolist(), "hidden_size": vec.shape[0]}
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
def chunk_text(text: str, chunk_size: int = 1200, overlap: int = 200) -> List[str]:
    """Split *text* into windows of at most *chunk_size* characters.

    Consecutive windows share *overlap* characters, so a sentence cut at a boundary still
    appears whole in one chunk. Windows start every ``chunk_size - overlap`` characters.
    ``\\r\\n`` line endings are normalised to ``\\n`` first. Empty text yields ``[]``.

    Raises:
        ValueError: if ``chunk_size <= 0`` or not ``0 <= overlap < chunk_size``. Such
            values would previously never advance the window and loop forever.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}")

    text = text.replace("\r\n", "\n")
    step = chunk_size - overlap
    chunks = []
    start = 0
    while start < len(text):
        end = min(len(text), start + chunk_size)
        chunks.append(text[start:end])
        if end >= len(text):
            break
        start += step
    return chunks
|
|
|
|
|
def build_context(
    user_query: str,
    index,
    index_model,
    docs: List[str],
    loaded_datasets: List,
    use_web: bool,
    web_k: int,
    max_chars: int = 6000,
) -> Tuple[str, List[Dict]]:
    """Collect retrieval context for *user_query* from uploaded files, datasets and the web.

    Returns:
        ``(context, sources)``. *context* is every piece joined with ``---`` separators and
        truncated to *max_chars* characters. *sources* records where each piece came from
        (``{"type": "file" | "dataset" | "web", ...}``) for display in the UI.

    Args:
        loaded_datasets: ``(repo_id, sample_text)`` pairs; entries with an empty sample
            are skipped.
        use_web: whether to run a Brave web search.
        web_k: number of web results to request.
        max_chars: cap on the returned context length (formerly hard-coded to 6000).
    """
    pieces: List[str] = []
    sources: List[Dict] = []

    # 1) Uploaded files: top vector-search hits, truncated to keep the prompt compact.
    if index is not None and index_model is not None and docs:
        for hit in search_index(user_query, index, index_model, docs, k=4):
            pieces.append(f"[FILE] {hit['text'][:800]}")
            sources.append({"type": "file", "text": hit["text"][:200]})

    # 2) Dataset previews.
    for rid, sample in loaded_datasets:
        if sample:
            pieces.append(f"[DATASET {rid}] {sample}")
            sources.append({"type": "dataset", "id": rid})

    # 3) Web results. brave_search() reports a missing key, request errors and "no results"
    #    as placeholder entries with an empty URL. Those are status messages, not evidence,
    #    so they must not reach the model as if they were search results.
    if use_web:
        for r in brave_search(user_query, count=web_k):
            url = r.get("url", "")
            if not url:
                continue
            title = r.get("title", "")
            snippet = r.get("snippet", "")
            pieces.append(f"[WEB] {title}\n{snippet}\n{url}")
            sources.append({"type": "web", "title": title, "url": url})

    context = "\n\n---\n\n".join(pieces)[:max_chars]
    return context, sources
|
|
|
|
|
def chat_answer(user_query: str, index, index_model, docs: List[str], loaded_datasets: List, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Answer *user_query* with the LLM, grounded in retrieved context.

    Context comes from build_context(). It is embedded in a user prompt that follows a
    fixed system prompt, and the pair is sent to Fireworks (temperature 0.4, at most
    1200 tokens).

    Returns:
        ``(answer_text, sources)``, where *sources* is exactly what build_context() returned.
    """
    context, sources = build_context(user_query, index, index_model, docs, loaded_datasets, use_web, web_k)

    system_prompt = (
        "You are a concise, careful bioinformatics assistant for protein and DNA. "
        "Answer with factual, verifiable statements. "
        "When uncertain, say so briefly. "
        "Never give medical advice. Provide short references as plain URLs or titles if present in context. "
        "User uploads and web/dataset snippets are provided as context below."
    )
    user_prompt = (
        f"Context:\n{context}\n\n"
        f"User question:\n{user_query}\n\n"
        "Answer in Korean if the user used Korean; otherwise match user language."
    )

    reply = call_fireworks(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.4,
        max_tokens=1200,
    )
    return reply, sources
|
|
|
|
|
|
|
|
|
|
|
# Page chrome. set_page_config must be the first Streamlit command executed in the run.
st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# TORCH_AVAILABLE is False when torch or transformers could not be imported at startup.
if not TORCH_AVAILABLE:
    st.warning("⏳ PyTorch is being installed. Some features may be unavailable initially. Please refresh in a minute.")
|
|
|
|
|
|
|
|
# Per-session state survives Streamlit reruns; seed each key only the first time:
#   docs            - text chunks from uploaded files
#   index           - vector index over `docs` (or None)
#   index_model     - embedder used to build `index` (or None)
#   loaded_datasets - (repo_id, sample_preview) pairs
for _state_key, _initial in (
    ("docs", []),
    ("index", None),
    ("index_model", None),
    ("loaded_datasets", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
|
|
|
|
|
|
|
|
with st.sidebar:
    st.header("Keys and settings")

    # SECURITY: these widgets are deliberately NOT pre-filled with get_secret(...). A
    # password input still sends its value to the browser, so pre-filling would hand the
    # deployment's keys (or a key another visitor typed, see os.environ below) to anyone
    # who opens the page. An empty field means "use the configured secret/environment value".
    fw_key = st.text_input(
        "FIREWORKS_API_KEY",
        value="",
        type="password",
        help="Leave empty to use the key configured in Space secrets or the environment.",
    )
    brave_key = st.text_input(
        "BRAVE_API_KEY",
        value="",
        type="password",
        help="Leave empty to use the key configured in Space secrets or the environment.",
    )

    # NOTE(review): os.environ is process-wide, so a key entered here is also used by other
    # sessions served by the same process, and it persists after the field is cleared.
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key

    fw_status = "configured" if get_secret("FIREWORKS_API_KEY", "") else "missing"
    brave_status = "configured" if get_secret("BRAVE_API_KEY", "") else "missing"
    st.caption(f"Fireworks key: {fw_status} · Brave key: {brave_status}")

    st.markdown("### Model selections")
    esm2_id = st.text_input(
        "Protein model (ESM-2)",
        value="facebook/esm2_t6_8M_UR50D",
        help="Try larger models like facebook/esm2_t33_650M_UR50D if resources allow.",
    )
    dna_id = st.text_input(
        "DNA model",
        value="zhihan1996/DNABERT-2-117M",
        help="Alternative: InstaDeepAI/nucleotide-transformer-500m-human-ref",
    )

    use_web = st.checkbox("Use Brave web search for context", value=True)
    web_k = st.slider("Web results", 1, 10, 4)

    st.markdown("### Datasets (optional)")
    dataset_ids = st.text_area(
        "Datasets to load (one per line)",
        value="",
        help="Enter Hugging Face dataset repo ids, e.g., 'genomics-benchmark/jaspar_motifs'",
    )

    st.divider()
    st.markdown("Files you upload are indexed locally and used for answers.")
|
|
|
|
|
|
|
|
# Main layout: one tab per feature area; later sections fill them by index.
tabs = st.tabs(["Chat", "Protein", "DNA", "Examples", "About"])

with st.expander("Upload files for context (txt/csv/json/fasta/vcf)", expanded=True):
    uploads = st.file_uploader(
        "Add files",
        type=["txt", "md", "csv", "tsv", "json", "fa", "fasta", "faa", "fna", "vcf"],
        accept_multiple_files=True,
        key="file_uploader",
    )

    if uploads:
        # Streamlit re-executes this script on every interaction (button click, slider
        # move, ...). Only re-read, re-chunk and re-embed when the set of uploaded files
        # changes; otherwise every unrelated click would rebuild the index from scratch.
        upload_signature = tuple(
            (up.name, up.size, getattr(up, "file_id", None)) for up in uploads
        )
        if st.session_state.get("upload_signature") != upload_signature:
            docs = []
            for up in uploads:
                try:
                    txt = load_text_from_file(up)
                    docs.extend(chunk_text(txt))
                except Exception as e:
                    st.warning(f"Failed to read {up.name}: {e}")

            st.session_state.docs = docs
            # Reset first: an index built from a previous file set must never be paired
            # with the new chunks (its row ids would point at the wrong texts).
            st.session_state.index = None
            st.session_state.index_model = None
            if docs and SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
                with st.spinner("Building vector index..."):
                    index, _emb, index_model = build_vector_index(docs)
                st.session_state.index = index
                st.session_state.index_model = index_model
            st.session_state.upload_signature = upload_signature

        st.caption(f"Indexed chunks: {len(st.session_state.docs)}")
        if st.session_state.docs and not (SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE):
            st.caption("Vector search is unavailable (sentence-transformers/FAISS missing); file context is disabled.")
    else:
        # All files removed: forget their chunks and index so they stop feeding chat context.
        st.session_state.docs = []
        st.session_state.index = None
        st.session_state.index_model = None
        st.session_state.upload_signature = None
        st.caption("No files uploaded yet")
|
|
|
|
|
|
|
|
def _dataset_preview(rid: str, max_chars: int = 500) -> str:
    """Return a JSON preview (at most *max_chars* characters) of the first row of *rid*.

    Splits are tried in order; the first split that yields a row wins, and "" is returned
    if none does. Streaming mode is tried first so only the first rows are fetched rather
    than the whole dataset. Datasets that cannot be streamed fall back to a regular load.
    Errors from the final load attempt propagate to the caller.
    """
    try:
        ds = load_dataset(rid, streaming=True)
    except Exception:
        ds = load_dataset(rid)
    for split in ds.keys():
        try:
            row = next(iter(ds[split]))
        except Exception:
            continue  # empty or unreadable split: best effort, try the next one
        # default=str keeps non-JSON values (bytes, dates, ...) from discarding the preview.
        return json.dumps(row, ensure_ascii=False, default=str)[:max_chars]
    return ""


if dataset_ids.strip() and not DATASETS_AVAILABLE:
    st.warning("The `datasets` package is not available; dataset context is disabled.")

requested_datasets = (
    [x.strip() for x in dataset_ids.splitlines() if x.strip()] if DATASETS_AVAILABLE else []
)

# Compare with what was *requested* last time, not with what loaded successfully. Otherwise
# one failing id makes the lists differ forever and every rerun reloads all datasets.
# An empty request also clears previously loaded datasets, so stale samples stop being
# used as chat context. To retry a failed id, edit the list.
if requested_datasets != st.session_state.get("requested_datasets"):
    st.session_state.requested_datasets = requested_datasets
    st.session_state.loaded_datasets = []
    for rid in requested_datasets:
        with st.spinner(f"Loading dataset {rid}..."):
            try:
                sample = _dataset_preview(rid)
            except Exception as e:
                st.error(f"Failed to load {rid}: {e}")
                continue
        st.session_state.loaded_datasets.append((rid, sample))
        st.success(f"Loaded {rid}")
|
|
|
|
|
|
|
|
with tabs[0]:
    st.subheader("Chat")
    q = st.text_area(
        "Ask a question about protein/DNA",
        # Default prompt (Korean): "How do ESM-2 embeddings help interpret protein function?"
        value="ESM-2 임베딩은 단백질 기능 해석에 어떻게 도움되나요?",
    )

    if st.button("Answer", type="primary"):
        if not q.strip():
            # Avoid a paid LLM/search round-trip for an empty question.
            st.warning("Please enter a question first.")
        else:
            with st.spinner("Thinking..."):
                ans, srcs = chat_answer(
                    q,
                    st.session_state.index,
                    st.session_state.index_model,
                    st.session_state.docs,
                    st.session_state.loaded_datasets,
                    use_web,
                    web_k,
                )
            st.write(ans)

            if srcs:
                st.markdown("#### Sources")
                for s in srcs:
                    if s.get("type") == "web" and s.get("url"):
                        st.markdown(f"- {s.get('title', 'web')}: {s.get('url')}")
                    elif s.get("type") == "dataset":
                        st.markdown(f"- dataset: {s.get('id')}")
                    elif s.get("type") == "file":
                        snippet = s.get("text", "")
                        st.markdown(f"- file snippet: {snippet[:120]}...")
|
|
|
|
|
|
|
|
with tabs[1]:
    st.subheader("Protein analysis")
    seq = st.text_area("Protein sequence (amino acids only)", value="MKTIIALSYIFCLVFADYKDDDDK")

    # Normalise once so the embedding and the stats see exactly the same residues:
    # drop FASTA '>' header lines, remove every kind of whitespace (spaces, tabs, \r, \n),
    # and upper-case so input case does not matter.
    residue_lines = [line for line in seq.splitlines() if not line.lstrip().startswith(">")]
    protein_seq = "".join("".join(residue_lines).split()).upper()

    col1, col2 = st.columns(2)
    with col1:
        st.caption("ESM-2 embedding")
        if st.button("Run ESM-2", key="run_esm2"):
            if not protein_seq:
                st.warning("Please enter a protein sequence first.")
            else:
                with st.spinner("Computing ESM-2 embedding..."):
                    out = esm2_embed(protein_seq, esm2_id)
                if "error" in out:
                    st.error(out["error"])
                else:
                    st.success(f"Vector size: {out['hidden_size']}")
                    st.json({"embedding_preview": out["embedding"][:8]})

    with col2:
        st.caption("Quick stats")
        st.write(f"Length: {len(protein_seq)}")
        st.write(f"Unique AAs: {''.join(sorted(set(protein_seq)))[:30]}")
|
|
|
|
|
|
|
|
with tabs[2]:
    st.subheader("DNA analysis")
    dseq = st.text_area("DNA sequence (ACGT only)", value="ATGCGTACGTAGCTAGCTAGCTAGGCTAGC")

    # Accept raw bases or pasted FASTA: drop '>' header lines, remove all whitespace
    # (spaces, tabs, \r, \n) and upper-case before embedding or computing stats.
    base_lines = [line for line in dseq.splitlines() if not line.lstrip().startswith(">")]
    dna_seq = "".join("".join(base_lines).split()).upper()

    col3, col4 = st.columns(2)
    with col3:
        st.caption("DNA embedding")
        if st.button("Run DNA embed", key="run_dna"):
            if not dna_seq:
                st.warning("Please enter a DNA sequence first.")
            else:
                with st.spinner("Computing DNA embedding..."):
                    out = dna_embed(dna_seq, dna_id)
                if "error" in out:
                    st.error(out["error"])
                else:
                    st.success(f"Vector size: {out['hidden_size']}")
                    st.json({"embedding_preview": out["embedding"][:8]})

    with col4:
        st.caption("GC content")
        # Ambiguous 'N' calls are excluded from both the numerator and the denominator.
        called = dna_seq.replace("N", "")
        gc = (called.count("G") + called.count("C")) / len(called) if called else 0
        st.write(f"Length: {len(called)}")
        st.write(f"GC: {gc:.3f}")
|
|
|
|
|
# Examples tab |
|
|
with tabs[3]: |
|
|
st.subheader("Examples") |
|
|
st.markdown(" |
|
|
st.markdown("- ์
๋ก๋ํ FASTA์์ ํน์ ๋จ๋ฐฑ์ง์ ๊ธฐ๋ฅ ์์ฝ๊ณผ ๋ณ์ด ์ํฅ ์ง๋ฌธ") |
|
|
st.markdown("- DNA ์์ด์์ ํ๋ก๋ชจํฐ ๊ฐ๋ฅ์ฑ๊ณผ ์ ์ฌ์ธ์ ๋ชจํฐํ ๊ด๋ จ ๊ทผ๊ฑฐ ์์ฒญ") |
|
|
st.markdown("- Enzyme active site ๊ทผ์ ๋ณ์ด์ ๋ฆฌ์คํฌ ํด์ (์ฐ๊ตฌ ๊ด์ )") |
|
|
st.markdown("- ENCODE/UniProt/AlphaFold ๊ฐ๋
์ค๋ช
์์ฒญ") |
|
|
st.markdown("- RAG ๊ธฐ๋ฐ์ผ๋ก ๋ฌธ์ ์ธ์ฉ๊ณผ ํจ๊ป ๊ฐ๋ต ๋ต๋ณ ์์ฒญ") |
|
|
|
|
|
|
|
|
with tabs[4]:
    st.subheader("About this Space")

    # (heading, bullet lines) pairs; every section is followed by one blank line.
    _about_sections = (
        (
            "**Models suggested:**",
            ("- ESM-2 for proteins", "- DNABERT-2 or Nucleotide Transformer for DNA"),
        ),
        (
            "**Common datasets:**",
            ("- UniProtKB, AlphaFoldDB, ENCODE, JASPAR, ClinVar",),
        ),
        (
            "**Features:**",
            (
                "- Web search powered by Brave Search API",
                "- LLM powered by Fireworks AI",
                "- Vector search with FAISS",
            ),
        ),
    )
    for _heading, _bullets in _about_sections:
        st.write(_heading)
        for _bullet in _bullets:
            st.write(_bullet)
        st.write("")

    st.info(DISCLAIMER)