|
|
import os |
|
|
import json |
|
|
from typing import List, Dict, Tuple |
|
|
|
|
|
import streamlit as st |
|
|
import requests |
|
|
|
|
|
|
|
|
# Optional heavy dependencies. Each import is attempted once at start-up and
# the outcome is recorded in a module-level *_AVAILABLE flag; code further
# down checks these flags before touching the corresponding library.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
    print("[WARNING] torch not available")
else:
    TORCH_AVAILABLE = True

try:
    from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("[WARNING] transformers not available")
else:
    TRANSFORMERS_AVAILABLE = True

try:
    from datasets import load_dataset
except ImportError:
    DATASETS_AVAILABLE = False
    print("[WARNING] datasets not available")
else:
    DATASETS_AVAILABLE = True

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    SENTENCE_TRANSFORMERS_AVAILABLE = False
    print("[WARNING] sentence_transformers not available")
else:
    SENTENCE_TRANSFORMERS_AVAILABLE = True

try:
    import faiss
except ImportError:
    FAISS_AVAILABLE = False
    print("[WARNING] faiss not available")
else:
    FAISS_AVAILABLE = True

try:
    from Bio import SeqIO
except ImportError:
    BIOPYTHON_AVAILABLE = False
    print("[WARNING] biopython not available")
else:
    BIOPYTHON_AVAILABLE = True
|
|
|
|
|
|
|
|
# Title passed to the page config and rendered as the main heading.
APP_TITLE = "BioSeq Chat: Protein & DNA Assistant"

# Caption rendered directly under the title.
DISCLAIMER = (
    "This tool is for research/education and is not a medical device. "
    "Do not use outputs for diagnosis or treatment decisions."
)
|
|
|
|
|
|
|
|
|
|
|
def get_secret(name: str, fallback: str = "") -> str:
    """Look up a configuration secret by name.

    Streamlit's secrets store is consulted first. If the name is not present
    there, or the store cannot be read at all, the process environment is
    used instead.

    Args:
        name: Key of the secret / environment variable.
        fallback: Value returned when the name is found in neither place.

    Returns:
        The stored value, or ``fallback`` when nothing is configured.
    """
    try:
        if hasattr(st, "secrets") and name in st.secrets:
            return st.secrets[name]
    except Exception:
        # Best effort: any failure while reading the secrets store simply
        # falls through to the environment lookup below. Catching Exception
        # (not a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        pass

    return os.environ.get(name, fallback)
|
|
|
|
|
def brave_search(query: str, count: int = 5) -> List[Dict]:
    """Query the Brave web-search API.

    Args:
        query: Search string sent as the ``q`` parameter.
        count: Maximum number of results requested and returned.

    Returns:
        A list of ``{"title", "url", "snippet"}`` dicts. The list is never
        empty: a missing API key, an empty result set, and any request or
        parsing error are each reported as a single placeholder entry with an
        empty ``url``.
    """
    token = get_secret("BRAVE_API_KEY", "")
    if not token:
        return [{
            "title": "BRAVE_API_KEY missing",
            "url": "",
            "snippet": "Set BRAVE_API_KEY in Space secrets or sidebar",
        }]

    endpoint = "https://api.search.brave.com/res/v1/web/search"
    request_headers = {
        "Accept": "application/json",
        "X-Subscription-Token": token,
    }
    request_params = {"q": query, "count": count}

    try:
        response = requests.get(
            endpoint,
            headers=request_headers,
            params=request_params,
            timeout=15,
        )
        response.raise_for_status()
        web_items = response.json().get("web", {}).get("results", [])[:count]
        hits = [
            {
                "title": entry.get("title", ""),
                "url": entry.get("url", ""),
                "snippet": entry.get("description", ""),
            }
            for entry in web_items
        ]
        return hits or [{"title": "No results", "url": "", "snippet": ""}]
    except Exception as exc:
        # Failures are reported in-band so the caller can show them inline.
        return [{"title": "Error", "url": "", "snippet": str(exc)}]
|
|
|
|
|
def call_llm(messages: List[Dict], temperature: float = 0.6, max_tokens: int = 1024) -> str:
    """Send a chat-completion request to the Fireworks AI API.

    Args:
        messages: Chat messages, each a dict with ``role`` and ``content``.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Completion length limit forwarded to the API.

    Returns:
        The content of the first returned choice. If the API key is missing
        or the request fails, a human-readable message is returned instead;
        this function does not raise.
    """
    token = get_secret("FIREWORKS_API_KEY", "")
    if not token:
        return "FIREWORKS_API_KEY missing. Set it in Secrets or sidebar."

    request_body = {
        "model": "accounts/fireworks/models/llama-v3p1-70b-instruct",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }

    try:
        response = requests.post(
            "https://api.fireworks.ai/inference/v1/chat/completions",
            headers=request_headers,
            json=request_body,
            timeout=60,
        )
        response.raise_for_status()
        first_choice = response.json()["choices"][0]
        return first_choice["message"]["content"]
    except Exception as exc:
        return f"[LLM Error] {exc}"
|
|
|
|
|
def load_file_text(upload) -> str:
    """Return the textual content of an uploaded file.

    The upload is read as bytes and decoded as UTF-8, silently dropping
    undecodable bytes. Files with a FASTA extension are additionally parsed
    with Biopython (when available) and re-emitted as ``>id`` / sequence
    pairs separated by blank lines.

    Args:
        upload: File-like object exposing ``name`` and a bytes-returning
            ``read()``.

    Returns:
        The normalised FASTA text when parsing succeeds and yields records,
        otherwise the decoded raw text. An empty string if the upload cannot
        be read.
    """
    import io  # local import keeps this block self-contained

    name = upload.name.lower()

    try:
        raw = upload.read()
        text = raw.decode("utf-8", errors="ignore")
    except Exception:
        return ""

    is_fasta = name.endswith((".fa", ".fasta", ".faa", ".fna"))
    if is_fasta and BIOPYTHON_AVAILABLE:
        try:
            # Parse the decoded text rather than the upload handle: the handle
            # yields bytes, while the FASTA parser expects a text-mode stream.
            records = list(SeqIO.parse(io.StringIO(text), "fasta"))
        except Exception:
            # Malformed FASTA: fall back to the raw text below.
            records = []
        if records:
            return "\n\n".join(f">{rec.id}\n{str(rec.seq)}" for rec in records)

    return text
|
|
|
|
|
def chunk_text(text: str, size: int = 1200, overlap: int = 200) -> List[str]:
    """Split text into fixed-size, overlapping character windows.

    Consecutive chunks start ``size - overlap`` characters apart, so each
    chunk repeats the last ``overlap`` characters of the previous one. The
    final chunk may be shorter than ``size``.

    Args:
        text: Text to split.
        size: Maximum number of characters per chunk; must be positive.
        overlap: Characters shared between neighbouring chunks; must satisfy
            ``0 <= overlap < size``.

    Returns:
        The chunks in order; an empty list for empty ``text``.

    Raises:
        ValueError: If ``size`` or ``overlap`` is out of range. An overlap
            of at least ``size`` would never advance the window.
    """
    if size <= 0:
        raise ValueError("size must be a positive integer")
    if not 0 <= overlap < size:
        raise ValueError("overlap must satisfy 0 <= overlap < size")

    total = len(text)
    stride = size - overlap
    chunks = []
    for start in range(0, total, stride):
        chunks.append(text[start:start + size])
        # Stop once a window reaches the end so no redundant tail chunk
        # (fully contained in the previous one) is emitted.
        if start + size >= total:
            break

    return chunks
|
|
|
|
|
def build_index(texts: List[str]):
    """Embed the given texts and load them into a FAISS inner-product index.

    Args:
        texts: Text chunks to embed, in the order they will be indexed.

    Returns:
        ``(index, model)`` on success, where ``model`` is the sentence
        encoder used for the embeddings. ``(None, None)`` when the required
        libraries are unavailable or when building fails; a failure is also
        reported through ``st.warning``.
    """
    if not (SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE):
        return None, None

    try:
        encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        vectors = encoder.encode(texts, show_progress_bar=False)

        # Index dimensionality is taken from the embedding width.
        store = faiss.IndexFlatIP(vectors.shape[1])
        store.add(vectors.astype("float32"))

        return store, encoder
    except Exception as exc:
        st.warning(f"Index build failed: {exc}")
        return None, None
|
|
|
|
|
def search_index(query: str, index, model, texts: List[str], k: int = 4) -> List[Dict]:
    """Return the indexed texts most similar to a query.

    Args:
        query: Free-text query to embed.
        index: Vector index exposing ``search(vectors, k)``; ``None`` disables
            the search.
        model: Encoder exposing ``encode(list_of_str)``; ``None`` disables the
            search.
        texts: The texts that were indexed, in index order.
        k: Number of neighbours to request.

    Returns:
        Up to ``k`` dicts with ``score`` (float) and ``text``, in the order
        returned by the index. An empty list when the index or model is
        missing, ``k`` is below 1, or the search fails. A failure is reported
        through ``st.warning`` rather than raised.
    """
    if index is None or model is None or k < 1:
        return []

    try:
        query_vec = model.encode([query])
        scores, ids = index.search(query_vec.astype("float32"), k)

        # Skip ids outside ``texts`` (e.g. negative filler ids returned when
        # fewer than ``k`` neighbours exist).
        return [
            {"score": float(score), "text": texts[idx]}
            for idx, score in zip(ids[0], scores[0])
            if 0 <= idx < len(texts)
        ]
    except Exception as e:
        # Surface the problem instead of silently returning nothing; this
        # mirrors how index build failures are reported.
        st.warning(f"Index search failed: {e}")
        return []
|
|
|
|
|
def esm2_embed(seq: str, model_name: str = "facebook/esm2_t6_8M_UR50D") -> Dict:
    """Compute a fixed-size embedding for a protein sequence.

    The sequence is tokenized, run through the masked-LM checkpoint with
    hidden states enabled, and the final hidden layer is averaged over
    dimension 1 (presumably the token axis; confirm against the model's
    output layout).

    Args:
        seq: Protein sequence as a string of residue letters.
        model_name: Hugging Face checkpoint identifier to load.

    Returns:
        ``{"embedding": list_of_floats, "size": int}`` on success, or
        ``{"error": message}`` when the input is blank, the required
        libraries are missing, or loading/inference fails. Never raises.
    """
    # Reject blank input up front: there is nothing meaningful to embed.
    if not seq or not seq.strip():
        return {"error": "Empty sequence"}

    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}

    try:
        # NOTE: tokenizer and weights are loaded on every call.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        model.eval()

        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt")
            outputs = model(**inputs, output_hidden_states=True)
            pooled = outputs.hidden_states[-1].mean(dim=1).squeeze(0)
            vec = pooled.numpy()

        return {
            "embedding": vec.tolist(),
            "size": vec.shape[0],
        }
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
def dna_embed(seq: str, model_name: str = "zhihan1996/DNABERT-2-117M") -> Dict:
    """Compute a fixed-size embedding for a DNA sequence.

    The sequence is tokenized (truncated to 512 tokens), run through the
    checkpoint, and the first model output is averaged over dimension 1
    (presumably the token axis; confirm against the model's output layout).

    Args:
        seq: DNA sequence as a string of nucleotide letters.
        model_name: Hugging Face checkpoint identifier to load. The checkpoint
            is loaded with ``trust_remote_code=True``, i.e. code shipped with
            the checkpoint is executed.

    Returns:
        ``{"embedding": list_of_floats, "size": int}`` on success, or
        ``{"error": message}`` when the input is blank, the required
        libraries are missing, or loading/inference fails. Never raises.
    """
    # Reject blank input up front: there is nothing meaningful to embed.
    if not seq or not seq.strip():
        return {"error": "Empty sequence"}

    if not TORCH_AVAILABLE or not TRANSFORMERS_AVAILABLE:
        return {"error": "PyTorch/Transformers not available"}

    try:
        # NOTE: tokenizer and weights are loaded on every call.
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        model.eval()

        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=512)
            outputs = model(**inputs)
            # Index positionally instead of reading ``.last_hidden_state``:
            # remote-code models may return a plain tuple (the DNABERT-2
            # model card reads ``model(inputs)[0]``), and positional access
            # also works on standard Hugging Face output objects.
            token_states = outputs[0]
            pooled = token_states.mean(dim=1).squeeze(0)
            vec = pooled.numpy()

        return {
            "embedding": vec.tolist(),
            "size": vec.shape[0],
        }
    except Exception as e:
        return {"error": str(e)}
|
|
|
|
|
def build_context(query: str, docs: List[str], index, model, use_web: bool, web_k: int) -> Tuple[str, List[Dict]]:
    """Assemble prompt context for a question from files and web search.

    Args:
        query: The user's question.
        docs: Text chunks that were indexed.
        index: Vector index over ``docs``; falsy skips file retrieval.
        model: Encoder paired with ``index``; falsy skips file retrieval.
        use_web: Whether to include web search results.
        web_k: Number of web results to request.

    Returns:
        ``(context, sources)``: the context string (pieces joined by a
        ``---`` separator and capped at 4000 characters) and one source
        descriptor per piece, in the same order.
    """
    snippets = []
    provenance = []

    # File retrieval needs all three: an index, its encoder, and the chunks.
    if index and model and docs:
        for hit in search_index(query, index, model, docs, k=4):
            body = hit["text"]
            snippets.append(f"[FILE] {body[:500]}")
            provenance.append({"type": "file", "text": body[:100]})

    if use_web:
        for item in brave_search(query, count=web_k):
            snippets.append(f"[WEB] {item['title']}\n{item['snippet']}")
            provenance.append({"type": "web", "title": item["title"], "url": item["url"]})

    joined = "\n\n---\n\n".join(snippets)
    return joined[:4000], provenance
|
|
|
|
|
def answer_question(query: str, context: str) -> str:
    """Ask the LLM to answer a question given retrieved context.

    Args:
        query: The user's question.
        context: Retrieved context text placed ahead of the question.

    Returns:
        Whatever ``call_llm`` returns for the assembled two-message chat
        (system instructions plus the context/question prompt).
    """
    system_prompt = (
        "You are a bioinformatics assistant. Be concise and factual. "
        "Never give medical advice. Answer in the user's language."
    )
    user_prompt = f"Context:\n{context}\n\nQuestion: {query}"

    return call_llm(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=0.4,
        max_tokens=1000,
    )
|
|
|
|
|
|
|
|
|
|
|
# ---- Page setup ----
# The page icon literal was mis-encoded (UTF-8 bytes shown through another
# code page); it is restored to the DNA emoji those bytes decode to.
st.set_page_config(page_title=APP_TITLE, page_icon="🧬", layout="wide")
st.title(APP_TITLE)
st.caption(DISCLAIMER)

# ---- Per-session state ----
# Seeded only when absent so values survive Streamlit's script reruns:
# the uploaded text chunks, the vector index over them, and its encoder.
if "docs" not in st.session_state:
    st.session_state.docs = []
if "index" not in st.session_state:
    st.session_state.index = None
if "model" not in st.session_state:
    st.session_state.model = None
|
|
|
|
|
|
|
|
# ---- Sidebar: API keys, model choices, web-search options ----
with st.sidebar:
    st.header("Configuration")

    # The key inputs deliberately start empty. Prefilling them with the
    # configured secret would send that secret to every visitor's browser,
    # where a password field can be revealed. A blank input simply means
    # "use the key from secrets / environment" via get_secret().
    fw_key = st.text_input(
        "FIREWORKS_API_KEY",
        value="",
        type="password",
        help="Leave blank to use the key configured in secrets or the environment.",
    )
    brave_key = st.text_input(
        "BRAVE_API_KEY",
        value="",
        type="password",
        help="Leave blank to use the key configured in secrets or the environment.",
    )

    # NOTE(review): os.environ is process-wide, so a key typed here is
    # presumably visible to every session served by this process - confirm
    # whether per-session storage is wanted instead.
    if fw_key:
        os.environ["FIREWORKS_API_KEY"] = fw_key
    if brave_key:
        os.environ["BRAVE_API_KEY"] = brave_key

    st.divider()

    # Hugging Face checkpoint identifiers used by the Protein and DNA tabs.
    esm_model = st.text_input(
        "ESM-2 Model",
        value="facebook/esm2_t6_8M_UR50D",
    )
    dna_model = st.text_input(
        "DNA Model",
        value="zhihan1996/DNABERT-2-117M",
    )

    use_web = st.checkbox("Enable web search", value=True)
    web_results = st.slider("Web results", 1, 10, 3)
|
|
|
|
|
|
|
|
# ---- Main layout ----
tab1, tab2, tab3, tab4 = st.tabs(["Chat", "Protein", "DNA", "About"])

# ---- File upload and indexing ----
with st.expander("📁 Upload Files", expanded=True):
    files = st.file_uploader(
        "Upload text/FASTA files",
        # .faa/.fna are included because the loader recognises them as FASTA.
        type=["txt", "fa", "fasta", "faa", "fna", "csv", "json"],
        accept_multiple_files=True,
    )

    if files:
        # Streamlit re-executes the whole script on every interaction, so
        # fingerprint the uploads and only re-chunk / re-embed when they
        # actually change. hash() is stable within one process, which is the
        # lifetime of the session state that stores the fingerprint.
        upload_signature = tuple((f.name, hash(f.getvalue())) for f in files)

        if st.session_state.get("upload_signature") != upload_signature:
            docs = []
            for f in files:
                try:
                    f.seek(0)  # make sure the loader reads from the start
                    text = load_file_text(f)
                    if text:
                        docs.extend(chunk_text(text))
                except Exception as e:
                    st.error(f"Error reading {f.name}: {e}")

            if docs:
                index, model = None, None
                if SENTENCE_TRANSFORMERS_AVAILABLE and FAISS_AVAILABLE:
                    with st.spinner("Building index..."):
                        index, model = build_index(docs)

                # Always store the build result alongside the new chunks so
                # an index built for earlier uploads is never paired with
                # them (a failed build stores None and disables retrieval).
                st.session_state.docs = docs
                st.session_state.index = index
                st.session_state.model = model
                st.session_state.upload_signature = upload_signature

        if st.session_state.get("upload_signature") == upload_signature:
            st.success(f"Loaded {len(st.session_state.docs)} chunks")
|
|
|
|
|
|
|
|
# ---- Chat tab: retrieval-augmented question answering ----
with tab1:
    # Emoji literals in this tab were mis-encoded; they are restored here
    # (the source icons are best-effort choices).
    st.subheader("💬 Chat Assistant")

    question = st.text_area(
        "Ask about proteins, DNA, or bioinformatics:",
        value="What is the role of ESM-2 embeddings in protein analysis?",
        height=100,
    )

    if st.button("Get Answer", type="primary"):
        if not get_secret("FIREWORKS_API_KEY"):
            st.error("Please set FIREWORKS_API_KEY")
        elif not question.strip():
            # Avoid spending a web search and an LLM call on an empty prompt.
            st.warning("Please enter a question")
        else:
            with st.spinner("Thinking..."):
                context, sources = build_context(
                    question,
                    st.session_state.docs,
                    st.session_state.index,
                    st.session_state.model,
                    use_web,
                    web_results,
                )
                answer = answer_question(question, context)

            st.markdown("### Answer")
            st.write(answer)

            if sources:
                st.markdown("### Sources")
                for s in sources:
                    if s["type"] == "web":
                        if s["url"]:
                            st.write(f"- 🌐 [{s['title']}]({s['url']})")
                        else:
                            # Placeholder entries (missing key, error, no
                            # results) carry no URL; show the title only
                            # instead of an empty link.
                            st.write(f"- 🌐 {s['title']}")
                    elif s["type"] == "file":
                        st.write(f"- 📄 File: {s['text'][:80]}...")
|
|
|
|
|
|
|
|
# ---- Protein tab: basic statistics plus ESM-2 embedding ----
with tab2:
    # Heading emoji restored from its mis-encoded form.
    st.subheader("🧬 Protein Analysis")

    protein_seq = st.text_area(
        "Enter protein sequence:",
        value="MKTIIALSYIFCLVFA",
        height=100,
    )

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Analyze Protein"):
            # Drop all whitespace (including line breaks from pasted,
            # wrapped sequences) so it is not counted as residues.
            seq = "".join(protein_seq.split()).upper()

            if not seq:
                st.warning("Please enter a protein sequence")
            else:
                st.write(f"**Length:** {len(seq)}")
                st.write(f"**Unique AAs:** {len(set(seq))}")

                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = esm2_embed(seq, esm_model)

                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        # Show only the first few values as a preview.
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")

    with col2:
        st.info("Amino acid composition and structure prediction features coming soon")
|
|
|
|
|
|
|
|
# ---- DNA tab: basic statistics plus DNA language-model embedding ----
with tab3:
    # Heading emoji restored from its mis-encoded form.
    st.subheader("🧬 DNA Analysis")

    dna_seq = st.text_area(
        "Enter DNA sequence:",
        value="ATGCGATCGTAGC",
        height=100,
    )

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Analyze DNA"):
            # Drop all whitespace (including line breaks from pasted,
            # wrapped sequences) so it does not skew length or GC content.
            seq = "".join(dna_seq.split()).upper()

            if not seq:
                st.warning("Please enter a DNA sequence")
            else:
                # Fraction of G and C bases; seq is non-empty here.
                gc = (seq.count("G") + seq.count("C")) / len(seq)

                st.write(f"**Length:** {len(seq)}")
                st.write(f"**GC Content:** {gc:.2%}")

                if TORCH_AVAILABLE and TRANSFORMERS_AVAILABLE:
                    with st.spinner("Computing embedding..."):
                        result = dna_embed(seq, dna_model)

                    if "error" in result:
                        st.error(result["error"])
                    else:
                        st.success(f"Embedding size: {result['size']}")
                        # Show only the first few values as a preview.
                        st.json({"preview": result["embedding"][:5]})
                else:
                    st.warning("PyTorch not available for embeddings")

    with col2:
        st.info("Motif analysis and structure prediction coming soon")
|
|
|
|
|
|
|
|
# ---- About tab: feature overview and dependency status ----
with tab4:
    # Emoji literals in this tab were mis-encoded; they are restored here
    # (the web-search and file icons are best-effort choices).
    st.subheader("ℹ️ About")

    # Built line by line so the rendered markdown does not depend on how
    # this block is indented.
    about_text = "\n".join([
        "### Features",
        "- 💬 RAG-based chat for bioinformatics questions",
        "- 🧬 Protein sequence analysis with ESM-2",
        "- 🧬 DNA sequence analysis with DNABERT-2",
        "- 🌐 Web search integration via Brave API",
        "- 📁 File upload and vector search",
        "",
        "### Models",
        "- **Proteins:** ESM-2 (Facebook)",
        # Attribution corrected: DNABERT-2 is not a Microsoft model.
        "- **DNA:** DNABERT-2 (Zhou et al.)",
        "- **LLM:** Llama 3.1 70B (via Fireworks)",
        "",
        "### Disclaimer",
        "This tool is for research and educational purposes only.",
        "Not for medical diagnosis or treatment decisions.",
    ])
    st.markdown(about_text)

    st.divider()
    st.subheader("System Status")

    # Availability flags recorded when the optional imports were attempted.
    deps = {
        "PyTorch": TORCH_AVAILABLE,
        "Transformers": TRANSFORMERS_AVAILABLE,
        "Sentence Transformers": SENTENCE_TRANSFORMERS_AVAILABLE,
        "FAISS": FAISS_AVAILABLE,
        "BioPython": BIOPYTHON_AVAILABLE,
        "Datasets": DATASETS_AVAILABLE,
    }

    for name, available in deps.items():
        if available:
            st.success(f"✅ {name}")
        else:
            st.warning(f"⚠️ {name} not available")