"""CounselAssistant — a Streamlit RAG chatbot over financial & legal corpora.

Pipeline: hybrid (dense FAISS + sparse BM25) retrieval -> cross-encoder
reranking -> FLAN-T5 answer generation with inline citations. All artifacts
are pulled from Hugging Face Hub repositories at startup.
"""

import streamlit as st
import os
import pickle

import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM
)
from huggingface_hub import hf_hub_download

# ======================================
# HUGGING FACE REPOSITORIES
# ======================================
DATA_REPO = "Ashokreddy6/CounselAssistantV2-data"
RETRIEVER_REPO = "Ashokreddy6/CounselAssistantV2-retriever"
RERANKER_REPO = "Ashokreddy6/CounselAssistantV2-reranker"
GEN_REPO = "Ashokreddy6/CounselAssistantV2-generator"
FAISS_REPO = "Ashokreddy6/CounselAssistantV2-faiss"

# ======================================
# STREAMLIT CONFIG
# ======================================
st.set_page_config(page_title="CounselAssistant", layout="wide")
st.title("💼 CounselAssistant - Financial & Legal Chatbot")
st.caption("Ask finance or legal questions from EDGAR, FIQA, CUAD, PhraseBank.")


# ======================================
# LOAD EVERYTHING
# ======================================
@st.cache_resource
def load_all():
    """Download and initialize every model/index once per Streamlit session.

    Returns a tuple of:
        (corpus_df, retriever, index, bm25, rr_tok, rr_mod,
         gen_tok, gen_mod, device)
    """
    # -----------------------------
    # Load Corpus
    # -----------------------------
    corpus_path = hf_hub_download(
        repo_id=DATA_REPO,
        filename="corpus_merged_df.csv"
    )
    corpus_df = pd.read_csv(corpus_path)
    # Guard against NaN / numeric cells: everything downstream expects str.
    corpus_df["text"] = corpus_df["text"].astype(str)

    # -----------------------------
    # Load BM25
    # -----------------------------
    # NOTE(review): unpickling a hub-hosted artifact executes arbitrary code;
    # acceptable only because the repo is owner-controlled.
    try:
        bm25_path = hf_hub_download(
            repo_id=DATA_REPO,
            filename="bm25.pkl"
        )
        with open(bm25_path, "rb") as f:
            bm25 = pickle.load(f)
        print("BM25 loaded from Hugging Face.")
    except Exception:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit). Fallback: rebuild the sparse
        # index from the corpus with naive whitespace tokenization.
        print("BM25 missing — rebuilding...")
        bm25 = BM25Okapi([t.split() for t in corpus_df["text"].tolist()])

    # -----------------------------
    # Load Retriever (SentenceTransformer)
    # -----------------------------
    retriever = SentenceTransformer(RETRIEVER_REPO)

    # -----------------------------
    # Load FAISS Index
    # -----------------------------
    faiss_file = hf_hub_download(
        repo_id=FAISS_REPO,
        filename="passages.index"
    )
    index = faiss.read_index(faiss_file)

    # -----------------------------
    # Load Reranker Cross-Encoder
    # -----------------------------
    rr_tok = AutoTokenizer.from_pretrained(RERANKER_REPO)
    rr_mod = AutoModelForSequenceClassification.from_pretrained(RERANKER_REPO)

    # -----------------------------
    # Load Generator (FLAN-T5)
    # -----------------------------
    gen_tok = AutoTokenizer.from_pretrained(GEN_REPO)
    gen_mod = AutoModelForSeq2SeqLM.from_pretrained(GEN_REPO)

    # -----------------------------
    # Device
    # -----------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rr_mod.to(device)
    gen_mod.to(device)

    return corpus_df, retriever, index, bm25, rr_tok, rr_mod, gen_tok, gen_mod, device


with st.spinner("Loading models and indexes from Hugging Face..."):
    (corpus_df, retriever, index, bm25,
     rr_tok, rr_mod, gen_tok, gen_mod, device) = load_all()


# ======================================
# DOMAIN DETECTION
# ======================================
def detect_domain(query: str) -> str:
    """Classify a query into a source domain via keyword matching.

    Checks are ordered; the first keyword list that matches wins, and
    unmatched queries fall through to "general".
    """
    q = query.lower()
    edgar_kw = ["revenue", "income", "earnings", "cash", "assets",
                "liabilities", "sales", "debt", "expenses", "filing",
                "10-k", "risk", "report"]
    phrasebank_kw = ["sentiment", "positive", "negative", "tone", "outlook"]
    fiqa_kw = ["analyst", "forecast", "opinion", "recommendation", "qa",
               "investor"]
    legal_kw = ["contract", "clause", "agreement", "liability", "section",
                "breach", "court", "legal"]

    if any(w in q for w in edgar_kw):
        return "edgar"
    if any(w in q for w in phrasebank_kw):
        return "financial_phrasebank"
    if any(w in q for w in fiqa_kw):
        return "fiqa"
    if any(w in q for w in legal_kw):
        return "legal"
    return "general"


# ======================================
# HYBRID SEARCH
# ======================================
def hybrid_search(query, top_k=10):
    """Retrieve candidate passage ids by fusing dense and sparse scores.

    Dense (FAISS) and sparse (BM25) scores are min-max normalized
    independently, then merged with a 0.6 / 0.4 weighting.

    Returns a list of at most `top_k` corpus row indices, best first.
    """
    # The "query: " prefix matches the retriever's training format
    # (E5-style); normalized embeddings make inner product = cosine.
    q_vec = retriever.encode(
        [f"query: {query}"],
        normalize_embeddings=True
    ).astype("float32")

    # Dense search
    D, I = index.search(q_vec, top_k)

    # Sparse search (BM25)
    bm_scores = bm25.get_scores(query.split())
    bm_top = np.argsort(-bm_scores)[:top_k]

    dense_rank = {int(i): float(D[0][pos]) for pos, i in enumerate(I[0])}
    bm_rank = {int(i): float(bm_scores[i]) for i in bm_top}

    def norm(scores):
        """Min-max normalize a {id: score} dict into [0, 1]."""
        if not scores:
            return {}
        vals = list(scores.values())
        mn, mx = min(vals), max(vals)
        # epsilon avoids division by zero when all scores are equal
        return {k: (v - mn) / (mx - mn + 1e-8) for k, v in scores.items()}

    d_norm = norm(dense_rank)
    b_norm = norm(bm_rank)

    # Fix: union over the two *normalized* dicts (the original mixed
    # `set(d_norm) | set(bm_rank)`; b_norm shares bm_rank's keys, so
    # the result is identical — this just removes the asymmetry).
    merged = {
        k: 0.6 * d_norm.get(k, 0) + 0.4 * b_norm.get(k, 0)
        for k in set(d_norm) | set(b_norm)
    }

    ranked = sorted(merged.items(), key=lambda x: x[1], reverse=True)
    return [i for i, _ in ranked[:top_k]]


# ======================================
# RERANK
# ======================================
def rerank(query, ids, top_k=5):
    """Re-score candidate ids with the cross-encoder; keep the best `top_k`."""
    texts = [corpus_df["text"].iloc[i] for i in ids]
    enc = rr_tok(
        [query] * len(texts),
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        scores = rr_mod(**enc).logits.squeeze(-1)
    order = torch.argsort(scores, descending=True).tolist()
    return [ids[i] for i in order[:top_k]]


# ======================================
# CONTEXT BUILDER WITH CITATIONS
# ======================================
def build_context(ids, max_chars=1200):
    """Concatenate passages (truncated to `max_chars` each) with citations.

    Each passage is tagged `[source#doc_id]`; missing metadata falls back
    to "Unknown" / the row index.
    """
    ctx = ""
    for rid in ids:
        row = corpus_df.iloc[rid]
        text = row["text"][:max_chars]
        src = row.get("source", "Unknown")
        docid = row.get("doc_id", rid)
        ctx += f"\n---\n{text}\n(Citation: [{src}#{docid}])"
    return ctx


# ======================================
# GENERATE ANSWER
# ======================================
def generate_answer(query, ids):
    """Generate a grounded answer for `query` from the passages in `ids`."""
    ctx = build_context(ids)
    prompt = (
        "You are a finance & legal assistant. "
        "Answer concisely using ONLY the provided context.\n\n"
        f"Context:\n{ctx}\n\n"
        f"Question: {query}\nAnswer:"
    )
    inputs = gen_tok(prompt, return_tensors="pt", truncation=True).to(device)
    # Greedy decoding. The original passed `temperature=0.2` without
    # `do_sample=True`, which transformers ignores (with a warning) —
    # dropped here for the same deterministic behavior, minus the warning.
    output = gen_mod.generate(
        **inputs,
        max_new_tokens=256
    )
    return gen_tok.decode(output[0], skip_special_tokens=True)


# ======================================
# MAIN PIPELINE
# ======================================
def ask(query):
    """Full RAG pipeline: detect domain, retrieve, rerank, generate."""
    domain = detect_domain(query)
    ids = hybrid_search(query)
    ids = rerank(query, ids)
    ans = generate_answer(query, ids)
    return domain, ans


# ======================================
# STREAMLIT UI
# ======================================
query = st.text_area(
    "💬 Ask your financial or legal question:",
    placeholder="E.g., 'Which clauses discuss liability in this agreement?'"
)

if st.button("Ask"):
    if not query.strip():
        st.warning("Please enter a question")
    else:
        with st.spinner("Thinking..."):
            domain, answer = ask(query)
        st.success(f"**Detected Domain:** {domain}")
        st.write("### 🧠 Answer")
        st.write(answer)

st.markdown("---")
st.caption("CounselAssistant © 2025 | Finance + Legal RAG Chatbot")