# CounselAssistant — Streamlit Space entry point
# (Hugging Face Spaces page chrome removed.)
| import streamlit as st | |
| import os, pickle | |
| import numpy as np | |
| import pandas as pd | |
| import faiss | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| from rank_bm25 import BM25Okapi | |
| from transformers import ( | |
| AutoTokenizer, AutoModelForSequenceClassification, | |
| AutoModelForSeq2SeqLM | |
| ) | |
| from huggingface_hub import hf_hub_download | |
# ======================================
# HUGGING FACE REPOSITORIES
# ======================================
# Each pipeline artifact lives in its own Hub repo under the same account:
# corpus/BM25 data, dense retriever, cross-encoder reranker, FLAN-T5
# generator, and the prebuilt FAISS index.
DATA_REPO = "Ashokreddy6/CounselAssistantV2-data"
RETRIEVER_REPO = "Ashokreddy6/CounselAssistantV2-retriever"
RERANKER_REPO = "Ashokreddy6/CounselAssistantV2-reranker"
GEN_REPO = "Ashokreddy6/CounselAssistantV2-generator"
FAISS_REPO = "Ashokreddy6/CounselAssistantV2-faiss"
# ======================================
# STREAMLIT CONFIG
# ======================================
# Page chrome: wide layout, title banner, and a short usage hint.
st.set_page_config(page_title="CounselAssistant", layout="wide")
st.title("💼 CounselAssistant - Financial & Legal Chatbot")
st.caption("Ask finance or legal questions from EDGAR, FIQA, CUAD, PhraseBank.")
# ======================================
# LOAD EVERYTHING
# ======================================
@st.cache_resource(show_spinner=False)  # cache: avoid re-downloading models on every Streamlit rerun
def load_all():
    """Download and initialise every artifact the RAG pipeline needs.

    Returns a 9-tuple:
        corpus_df : pandas.DataFrame with (at least) a "text" column
        retriever : SentenceTransformer used for dense query/passage encoding
        index     : FAISS index over the passage embeddings
        bm25      : BM25Okapi sparse index over the corpus texts
        rr_tok, rr_mod   : cross-encoder tokenizer/model for reranking
        gen_tok, gen_mod : FLAN-T5 tokenizer/model for answer generation
        device    : "cuda" when available, else "cpu"
    """
    # -----------------------------
    # Load Corpus
    # -----------------------------
    corpus_path = hf_hub_download(
        repo_id=DATA_REPO,
        filename="corpus_merged_df.csv"
    )
    corpus_df = pd.read_csv(corpus_path)
    # Guard against NaN / numeric cells so downstream .split()/slicing work.
    corpus_df["text"] = corpus_df["text"].astype(str)

    # -----------------------------
    # Load BM25 (prebuilt pickle preferred; rebuild from corpus on failure)
    # -----------------------------
    try:
        bm25_path = hf_hub_download(
            repo_id=DATA_REPO,
            filename="bm25.pkl"
        )
        # NOTE(review): pickle.load on a downloaded artifact executes arbitrary
        # code if the repo is compromised — acceptable only because the repo is
        # owned by this project.
        with open(bm25_path, "rb") as f:
            bm25 = pickle.load(f)
        print("BM25 loaded from Hugging Face.")
    except Exception:  # was a bare `except:` — don't swallow SystemExit/KeyboardInterrupt
        print("BM25 missing — rebuilding...")
        bm25 = BM25Okapi([t.split() for t in corpus_df["text"].tolist()])

    # -----------------------------
    # Load Retriever (SentenceTransformer)
    # -----------------------------
    retriever = SentenceTransformer(RETRIEVER_REPO)

    # -----------------------------
    # Load FAISS Index
    # -----------------------------
    faiss_file = hf_hub_download(
        repo_id=FAISS_REPO,
        filename="passages.index"
    )
    index = faiss.read_index(faiss_file)

    # -----------------------------
    # Load Reranker Cross-Encoder
    # -----------------------------
    rr_tok = AutoTokenizer.from_pretrained(RERANKER_REPO)
    rr_mod = AutoModelForSequenceClassification.from_pretrained(RERANKER_REPO)

    # -----------------------------
    # Load Generator (FLAN-T5)
    # -----------------------------
    gen_tok = AutoTokenizer.from_pretrained(GEN_REPO)
    gen_mod = AutoModelForSeq2SeqLM.from_pretrained(GEN_REPO)

    # -----------------------------
    # Device placement for the torch models (retriever handles its own device)
    # -----------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rr_mod.to(device)
    gen_mod.to(device)

    return corpus_df, retriever, index, bm25, rr_tok, rr_mod, gen_tok, gen_mod, device
# Load all heavy artifacts once at startup, with a visible progress spinner.
with st.spinner("Loading models and indexes from Hugging Face..."):
    (corpus_df, retriever, index, bm25,
     rr_tok, rr_mod, gen_tok, gen_mod, device) = load_all()
# ======================================
# DOMAIN DETECTION
# ======================================
def detect_domain(query: str) -> str:
    """Tag a query with its most likely source domain via keyword matching.

    Buckets are checked in a fixed priority order (EDGAR, PhraseBank,
    FIQA, legal); the first bucket whose keyword occurs as a substring
    of the lowercased query wins.  Returns "general" when nothing matches.
    """
    text = query.lower()
    buckets = [
        ("edgar", ["revenue", "income", "earnings", "cash", "assets",
                   "liabilities", "sales", "debt", "expenses", "filing",
                   "10-k", "risk", "report"]),
        ("financial_phrasebank", ["sentiment", "positive", "negative",
                                  "tone", "outlook"]),
        ("fiqa", ["analyst", "forecast", "opinion", "recommendation",
                  "qa", "investor"]),
        ("legal", ["contract", "clause", "agreement", "liability",
                   "section", "breach", "court", "legal"]),
    ]
    for domain, keywords in buckets:
        if any(kw in text for kw in keywords):
            return domain
    return "general"
# ======================================
# HYBRID SEARCH
# ======================================
def hybrid_search(query, top_k=10):
    """Blend dense (FAISS) and sparse (BM25) retrieval into one ranking.

    Each channel's top-k scores are min-max normalised, then merged as
    0.6 * dense + 0.4 * sparse.  Returns up to `top_k` corpus row ids,
    best first.
    """
    # Dense channel: encode with the "query: " prefix the retriever expects.
    query_vec = retriever.encode(
        [f"query: {query}"],
        normalize_embeddings=True
    ).astype("float32")
    dists, idxs = index.search(query_vec, top_k)
    dense_scores = {int(doc): float(score) for doc, score in zip(idxs[0], dists[0])}

    # Sparse channel: BM25 over whitespace tokens, top-k by raw score.
    all_bm = bm25.get_scores(query.split())
    sparse_scores = {int(doc): float(all_bm[doc]) for doc in np.argsort(-all_bm)[:top_k]}

    def _minmax(scores):
        # Min-max normalise a {doc_id: score} map; empty input stays empty.
        if not scores:
            return {}
        lo, hi = min(scores.values()), max(scores.values())
        return {doc: (s - lo) / (hi - lo + 1e-8) for doc, s in scores.items()}

    dense_norm = _minmax(dense_scores)
    sparse_norm = _minmax(sparse_scores)

    combined = {}
    for doc in set(dense_norm) | set(sparse_scores):
        combined[doc] = 0.6 * dense_norm.get(doc, 0) + 0.4 * sparse_norm.get(doc, 0)

    best_first = sorted(combined, key=combined.get, reverse=True)
    return best_first[:top_k]
# ======================================
# RERANK
# ======================================
def rerank(query, ids, top_k=5):
    """Re-score candidate passages with the cross-encoder; keep the best.

    `ids` are corpus row positions.  The query is paired with each passage
    text, and the model's logit is used directly as the relevance score.
    """
    passages = [corpus_df["text"].iloc[i] for i in ids]
    batch = rr_tok(
        [query] * len(passages),
        passages,
        padding=True,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        relevance = rr_mod(**batch).logits.squeeze(-1)
    best_first = torch.argsort(relevance, descending=True).tolist()
    return [ids[pos] for pos in best_first[:top_k]]
# ======================================
# CONTEXT BUILDER WITH CITATIONS
# ======================================
def build_context(ids, max_chars=1200):
    """Concatenate the passages for `ids`, each followed by a citation tag.

    Each passage is truncated to `max_chars` characters.  The citation
    uses the row's "source" and "doc_id" columns when present, falling
    back to "Unknown" and the row id respectively.
    """
    parts = []
    for rid in ids:
        row = corpus_df.iloc[rid]
        snippet = row["text"][:max_chars]
        source = row.get("source", "Unknown")
        doc_id = row.get("doc_id", rid)
        parts.append(f"\n---\n{snippet}\n(Citation: [{source}#{doc_id}])")
    return "".join(parts)
# ======================================
# GENERATE ANSWER
# ======================================
def generate_answer(query, ids):
    """Generate an answer with FLAN-T5, grounded in the cited context.

    Builds a context block from the reranked passage ids, wraps it in an
    instruction prompt, and decodes the generator's output to a string.
    """
    ctx = build_context(ids)
    prompt = (
        "You are a finance & legal assistant. "
        "Answer concisely using ONLY the provided context.\n\n"
        f"Context:\n{ctx}\n\n"
        f"Question: {query}\nAnswer:"
    )
    inputs = gen_tok(prompt, return_tensors="pt", truncation=True).to(device)
    output = gen_mod.generate(
        **inputs,
        max_new_tokens=256,
        # Fix: `temperature` is silently ignored by generate() unless
        # sampling is enabled — the original ran greedy decoding.
        do_sample=True,
        temperature=0.2
    )
    return gen_tok.decode(output[0], skip_special_tokens=True)
# ======================================
# MAIN PIPELINE
# ======================================
def ask(query):
    """Run the full RAG pipeline: detect domain, retrieve, rerank, generate.

    Returns a (domain, answer) pair for display in the UI.
    """
    detected = detect_domain(query)
    candidates = hybrid_search(query)
    top_ids = rerank(query, candidates)
    answer = generate_answer(query, top_ids)
    return detected, answer
# ======================================
# STREAMLIT UI
# ======================================
query = st.text_area(
    "💬 Ask your financial or legal question:",
    placeholder="E.g., 'Which clauses discuss liability in this agreement?'"
)

if st.button("Ask"):
    if not query.strip():
        # Guard: reject empty / whitespace-only questions.
        st.warning("Please enter a question")
    else:
        with st.spinner("Thinking..."):
            domain, answer = ask(query)
        st.success(f"**Detected Domain:** {domain}")
        st.write("### 🧠 Answer")
        st.write(answer)

st.markdown("---")
st.caption("CounselAssistant © 2025 | Finance + Legal RAG Chatbot")