# Hugging Face Hub page residue, preserved as comments so the module stays valid Python:
# Ashokreddy6's picture
# Upload app.py with huggingface_hub
# 89000fe verified
import streamlit as st
import os, pickle
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
AutoModelForSeq2SeqLM
)
from huggingface_hub import hf_hub_download
# ======================================
# HUGGING FACE REPOSITORIES
# ======================================
# Hub repo IDs for every artifact downloaded at startup by load_all().
DATA_REPO = "Ashokreddy6/CounselAssistantV2-data"            # corpus CSV + BM25 pickle
RETRIEVER_REPO = "Ashokreddy6/CounselAssistantV2-retriever"  # SentenceTransformer dense encoder
RERANKER_REPO = "Ashokreddy6/CounselAssistantV2-reranker"    # cross-encoder reranker
GEN_REPO = "Ashokreddy6/CounselAssistantV2-generator"        # FLAN-T5 answer generator
FAISS_REPO = "Ashokreddy6/CounselAssistantV2-faiss"          # prebuilt FAISS index
# ======================================
# STREAMLIT CONFIG
# ======================================
# Must run before any other Streamlit call in the script.
st.set_page_config(page_title="CounselAssistant", layout="wide")
st.title("💼 CounselAssistant - Financial & Legal Chatbot")
st.caption("Ask finance or legal questions from EDGAR, FIQA, CUAD, PhraseBank.")
# ======================================
# LOAD EVERYTHING
# ======================================
@st.cache_resource
def load_all():
    """Download and initialize every artifact the RAG pipeline needs.

    Cached via st.cache_resource so models are loaded once per process,
    not on every Streamlit rerun.

    Returns:
        Tuple of
        (corpus_df, retriever, index, bm25, rr_tok, rr_mod, gen_tok, gen_mod, device):
          corpus_df  -- pandas DataFrame of passages; "text" column coerced to str
          retriever  -- SentenceTransformer dense encoder
          index      -- FAISS index over the passage embeddings
          bm25       -- BM25Okapi sparse index over whitespace-tokenized text
          rr_tok, rr_mod   -- cross-encoder reranker tokenizer + model
          gen_tok, gen_mod -- FLAN-T5 generator tokenizer + model
          device     -- "cuda" if available, else "cpu"
    """
    # -----------------------------
    # Load Corpus
    # -----------------------------
    corpus_path = hf_hub_download(
        repo_id=DATA_REPO,
        filename="corpus_merged_df.csv"
    )
    corpus_df = pd.read_csv(corpus_path)
    corpus_df["text"] = corpus_df["text"].astype(str)
    # -----------------------------
    # Load BM25 — best-effort: fall back to rebuilding from the corpus
    # if the pickle is missing or unreadable.
    # -----------------------------
    try:
        bm25_path = hf_hub_download(
            repo_id=DATA_REPO,
            filename="bm25.pkl"
        )
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # this is safe only as long as DATA_REPO is a trusted repository.
        with open(bm25_path, "rb") as f:
            bm25 = pickle.load(f)
        print("BM25 loaded from Hugging Face.")
    except Exception:
        # Fix: was a bare `except:` — that also swallows SystemExit and
        # KeyboardInterrupt. Exception keeps the best-effort fallback
        # without masking interpreter-level signals.
        print("BM25 missing — rebuilding...")
        bm25 = BM25Okapi([t.split() for t in corpus_df["text"].tolist()])
    # -----------------------------
    # Load Retriever (SentenceTransformer)
    # -----------------------------
    retriever = SentenceTransformer(RETRIEVER_REPO)
    # -----------------------------
    # Load FAISS Index
    # -----------------------------
    faiss_file = hf_hub_download(
        repo_id=FAISS_REPO,
        filename="passages.index"
    )
    index = faiss.read_index(faiss_file)
    # -----------------------------
    # Load Reranker Cross-Encoder
    # -----------------------------
    rr_tok = AutoTokenizer.from_pretrained(RERANKER_REPO)
    rr_mod = AutoModelForSequenceClassification.from_pretrained(RERANKER_REPO)
    # -----------------------------
    # Load Generator (FLAN-T5)
    # -----------------------------
    gen_tok = AutoTokenizer.from_pretrained(GEN_REPO)
    gen_mod = AutoModelForSeq2SeqLM.from_pretrained(GEN_REPO)
    # -----------------------------
    # Device placement (GPU when available)
    # -----------------------------
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rr_mod.to(device)
    gen_mod.to(device)
    return corpus_df, retriever, index, bm25, rr_tok, rr_mod, gen_tok, gen_mod, device
# Load all artifacts once at startup; st.cache_resource inside load_all()
# makes subsequent Streamlit reruns hit the cache instead of re-downloading.
with st.spinner("Loading models and indexes from Hugging Face..."):
    (corpus_df, retriever, index, bm25,
     rr_tok, rr_mod, gen_tok, gen_mod, device) = load_all()
# ======================================
# DOMAIN DETECTION
# ======================================
def detect_domain(query: str) -> str:
    """Tag a query with its most likely corpus domain via keyword lookup.

    Domains are checked in priority order (EDGAR financials first, then
    PhraseBank sentiment, FIQA opinions, and legal); the first list with any
    keyword appearing as a substring of the lowercased query wins.

    Returns:
        One of "edgar", "financial_phrasebank", "fiqa", "legal", or
        "general" when nothing matches.
    """
    text = query.lower()
    keyword_table = (
        ("edgar", ("revenue", "income", "earnings", "cash", "assets",
                   "liabilities", "sales", "debt", "expenses", "filing",
                   "10-k", "risk", "report")),
        ("financial_phrasebank", ("sentiment", "positive", "negative",
                                  "tone", "outlook")),
        ("fiqa", ("analyst", "forecast", "opinion", "recommendation",
                  "qa", "investor")),
        ("legal", ("contract", "clause", "agreement", "liability",
                   "section", "breach", "court", "legal")),
    )
    for domain, keywords in keyword_table:
        if any(kw in text for kw in keywords):
            return domain
    return "general"
# ======================================
# HYBRID SEARCH
# ======================================
def hybrid_search(query, top_k=10):
    """Retrieve candidate passage indices via a dense + sparse ensemble.

    Dense arm: FAISS inner-product search over normalized embeddings
    (E5-style "query: " prefix). Sparse arm: BM25 over whitespace tokens.
    Each arm's scores are min-max normalized and merged with a fixed
    60/40 dense/sparse weighting.

    Args:
        query: user question (plain text).
        top_k: candidates taken from each arm and returned overall.

    Returns:
        List of corpus row indices, best first (at most top_k).
    """
    q_vec = retriever.encode(
        [f"query: {query}"],
        normalize_embeddings=True
    ).astype("float32")
    # Dense search
    D, I = index.search(q_vec, top_k)
    # Sparse search (BM25)
    bm_scores = bm25.get_scores(query.split())
    bm_top = np.argsort(-bm_scores)[:top_k]
    # Robustness: FAISS pads with id -1 when the index holds fewer than
    # top_k vectors; drop those so they never index the corpus.
    dense_rank = {int(i): float(D[0][pos]) for pos, i in enumerate(I[0]) if i != -1}
    bm_rank = {int(i): float(bm_scores[i]) for i in bm_top}

    def norm(scores):
        # Min-max normalize so the two arms are on a comparable [0, 1]
        # scale; epsilon avoids division by zero when all scores tie.
        if not scores:
            return {}
        vals = scores.values()
        mn, mx = min(vals), max(vals)
        return {k: (v - mn) / (mx - mn + 1e-8) for k, v in scores.items()}

    d_norm = norm(dense_rank)
    b_norm = norm(bm_rank)
    # Fix: union the two *normalized* dicts. The original mixed b_norm
    # values with raw bm_rank keys — same keys today, but inconsistent
    # and fragile if norm() ever filters entries.
    merged = {
        k: 0.6 * d_norm.get(k, 0.0) + 0.4 * b_norm.get(k, 0.0)
        for k in set(d_norm) | set(b_norm)
    }
    ranked = sorted(merged.items(), key=lambda item: item[1], reverse=True)
    return [idx for idx, _ in ranked[:top_k]]
# ======================================
# RERANK
# ======================================
def rerank(query, ids, top_k=5):
    """Re-score candidates with the cross-encoder and keep the best.

    Scores each (query, passage) pair in one batch with the reranker
    model and returns the ids reordered by descending relevance.

    Args:
        query: user question.
        ids: corpus row indices produced by hybrid_search.
        top_k: number of ids to keep after reranking.

    Returns:
        Up to top_k corpus row indices, most relevant first.
    """
    candidates = [corpus_df["text"].iloc[idx] for idx in ids]
    batch = rr_tok(
        [query] * len(candidates),
        candidates,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = rr_mod(**batch).logits.squeeze(-1)
    best_positions = logits.argsort(descending=True)[:top_k].tolist()
    return [ids[pos] for pos in best_positions]
# ======================================
# CONTEXT BUILDER WITH CITATIONS
# ======================================
def build_context(ids, max_chars=1200):
    """Concatenate retrieved passages into one prompt-ready context string.

    Each passage is truncated to max_chars and followed by a citation tag
    built from its "source" and "doc_id" columns (falling back to
    "Unknown" / the row index when absent).

    Args:
        ids: corpus row indices to include, in order.
        max_chars: per-passage character cap.

    Returns:
        Single string of "---"-separated passages with citation markers.
    """
    chunks = []
    for rid in ids:
        row = corpus_df.iloc[rid]
        snippet = row["text"][:max_chars]
        source = row.get("source", "Unknown")
        doc_id = row.get("doc_id", rid)
        chunks.append(f"\n---\n{snippet}\n(Citation: [{source}#{doc_id}])")
    return "".join(chunks)
# ======================================
# GENERATE ANSWER
# ======================================
def generate_answer(query, ids):
    """Generate a grounded answer with FLAN-T5 from retrieved passages.

    Builds a citation-bearing context via build_context, wraps it in an
    instruction prompt, and decodes up to 256 new tokens.

    Args:
        query: user question.
        ids: corpus row indices selected by the retrieval/rerank stages.

    Returns:
        Decoded answer string (special tokens stripped).
    """
    ctx = build_context(ids)
    prompt = (
        "You are a finance & legal assistant. "
        "Answer concisely using ONLY the provided context.\n\n"
        f"Context:\n{ctx}\n\n"
        f"Question: {query}\nAnswer:"
    )
    inputs = gen_tok(prompt, return_tensors="pt", truncation=True).to(device)
    output = gen_mod.generate(
        **inputs,
        max_new_tokens=256,
        # Fix: transformers ignores `temperature` (and warns) unless
        # sampling is enabled; do_sample=True makes the 0.2 temperature
        # actually take effect (near-greedy, lightly stochastic decoding).
        do_sample=True,
        temperature=0.2
    )
    return gen_tok.decode(output[0], skip_special_tokens=True)
# ======================================
# MAIN PIPELINE
# ======================================
def ask(query):
    """End-to-end RAG pipeline: tag domain, retrieve, rerank, generate.

    Args:
        query: user question.

    Returns:
        (domain, answer) tuple — the detected domain label and the
        generated answer text.
    """
    candidate_ids = rerank(query, hybrid_search(query))
    answer = generate_answer(query, candidate_ids)
    return detect_domain(query), answer
# ======================================
# STREAMLIT UI
# ======================================
# Free-text input for the user's finance/legal question.
query = st.text_area(
    "💬 Ask your financial or legal question:",
    placeholder="E.g., 'Which clauses discuss liability in this agreement?'"
)
if st.button("Ask"):
    if not query.strip():
        # Guard against empty / whitespace-only submissions.
        st.warning("Please enter a question")
    else:
        with st.spinner("Thinking..."):
            # Full pipeline: detect domain -> retrieve -> rerank -> generate.
            domain, answer = ask(query)
        st.success(f"**Detected Domain:** {domain}")
        st.write("### 🧠 Answer")
        st.write(answer)
st.markdown("---")
st.caption("CounselAssistant © 2025 | Finance + Legal RAG Chatbot")