test / src /streamlit_app.py
omarkashif's picture
Update src/streamlit_app.py
51eb084 verified
raw
history blame
10.6 kB
import os
# Hugging Face cache locations. They are set here, between the imports, so
# they are already present in the environment when `sentence_transformers` is
# imported below. NOTE(review): presumably the HF libraries read these at
# import time, and /tmp is used because the default cache directory is not
# writable in the deployment environment -- confirm. setdefault() leaves any
# value already provided by the environment untouched.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")
import streamlit as st
import openai
import psycopg2
from collections import deque
from sentence_transformers import SentenceTransformer
import re
# Setup
# OpenAI client shared by every LLM call in this app; the API key is read
# from the OPENAI_API_KEY environment variable.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Chat model used for query rewriting, reranking and answer generation.
ll_model = 'gpt-4o-mini'
# ── PostgreSQL connection ───────────────────────────────────
def get_db_connection():
    """Open and return a new psycopg2 connection to the Postgres database.

    Connection settings are read from the environment: RDS_HOST, RDS_PORT
    (5432 when unset), RDS_DB, RDS_USER and RDS_PASS. Every call opens a
    fresh connection; the caller is responsible for closing it.
    """
    settings = dict(
        host=os.getenv("RDS_HOST"),
        port=os.getenv("RDS_PORT", 5432),
        dbname=os.getenv("RDS_DB"),
        user=os.getenv("RDS_USER"),
        password=os.getenv("RDS_PASS"),
    )
    return psycopg2.connect(**settings)
# ── BGE embedding model ─────────────────────────────────────
@st.cache_resource
def _load_embedding_model():
    """Load the BGE sentence-embedding model once per server process.

    Streamlit re-executes this whole script on every user interaction, so a
    bare module-level ``SentenceTransformer(...)`` call reloaded the model
    weights on every rerun. ``st.cache_resource`` keeps a single instance
    alive across reruns and sessions.
    """
    return SentenceTransformer('BAAI/bge-small-en-v1.5')


# Module-level name kept for get_embedding().
model = _load_embedding_model()
def retrieve_summaries(query, top_k=40):
    """Return the ``top_k`` chunk summaries nearest to ``query``.

    The query is embedded with ``get_embedding`` and compared against
    ``public.case_chunks.embedding`` with pgvector's ``<=>`` (cosine distance)
    operator, nearest first.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_summary`` and ``similarity`` (``1 - distance``), ordered from
        most to least similar. On any error the message is shown with
        ``st.error`` and an empty list is returned.
    """
    try:
        embedding = get_embedding(query)
        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                # The query vector is bound twice: once for the reported
                # similarity column and once for the ORDER BY.
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_summary,
                        1 - (embedding <=> %s::vector) AS similarity
                    FROM public.case_chunks
                    ORDER BY embedding <=> %s::vector
                    LIMIT %s;
                """, [embedding, embedding, top_k])
                rows = cur.fetchall()
        finally:
            # Release the connection even when execute/fetch raises, so a
            # failing query cannot leak connections.
            conn.close()
        return [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_summary": row[3],
                "similarity": row[4],
            }
            for row in rows
        ]
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []
# ── STEP 2: LLM picks best chunks based on summaries ───────
def parse_selected_ids(raw, candidates, final_k):
    """Extract up to ``final_k`` candidate IDs from an LLM reply.

    Every run of digits in ``raw`` is read as an ID, so comma-, space- or
    newline-separated lists and replies with a short preamble all parse.
    IDs that are not among ``candidates`` (hallucinated by the model) and
    repeated IDs are dropped; the model's order is preserved.

    Args:
        raw: the model's reply text (``None`` is treated as empty).
        candidates: dicts carrying an integer ``id`` key.
        final_k: maximum number of IDs to return.

    Returns:
        A list of at most ``final_k`` distinct candidate IDs.
    """
    valid_ids = {c["id"] for c in candidates}
    selected = []
    for token in re.findall(r"\d+", raw or ""):
        chunk_id = int(token)
        if chunk_id in valid_ids and chunk_id not in selected:
            selected.append(chunk_id)
    return selected[:final_k]


def rerank_with_llm(query, candidates, final_k=10):
    """Ask the chat model to pick the ``final_k`` most relevant candidates.

    Args:
        query: the (rewritten) user query.
        candidates: dicts from ``retrieve_summaries``; their ``id``,
            ``case_id`` and ``chunk_summary`` are shown to the model.
        final_k: maximum number of IDs to return.

    Returns:
        Candidate IDs in the model's preference order. If the API call fails,
        or the reply contains no usable candidate ID, the IDs of the first
        ``final_k`` candidates are returned instead (their incoming order,
        i.e. vector-similarity order when they come from
        ``retrieve_summaries``).
    """
    summary_list = "\n".join(
        f"[ID: {c['id']}] Case: {c['case_id']} | Summary: {c['chunk_summary']}"
        for c in candidates
    )
    messages = [
        {"role": "system", "content":
            "You are a legal research assistant. Given a user query and a list of document chunk summaries, "
            "select the most relevant chunk IDs that would best answer the query. "
            "Return ONLY a comma-separated list of IDs, nothing else. Example: 12,45,67,23"
        },
        {"role": "user", "content":
            f"Query: {query}\n\n"
            f"Chunks:\n{summary_list}\n\n"
            f"Select the {final_k} most relevant chunk IDs."
        },
    ]
    # Fallback: just return top final_k by similarity.
    fallback = [c["id"] for c in candidates[:final_k]]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.0,
            max_tokens=200,
        )
        raw = resp.choices[0].message.content
    except Exception as e:
        st.error(f"Rerank error: {e}")
        return fallback
    # An unusable reply must not wipe out all retrieval context.
    return parse_selected_ids(raw, candidates, final_k) or fallback
# ── STEP 3: fetch full chunk_text for selected IDs only ────
def fetch_chunks_by_ids(selected_ids):
    """Load the full chunk rows for ``selected_ids``, in the given order.

    ``WHERE id = ANY(...)`` returns rows in no guaranteed order, so the rows
    are re-sorted to follow ``selected_ids``. This keeps the caller's ranking
    (e.g. the reranker's preference order) in the context built from them.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_text`` and ``chunk_summary``. IDs with no matching row are
        simply absent. Returns ``[]`` for an empty input, or on any error
        (the message is shown with ``st.error``).
    """
    ids = list(selected_ids or ())
    if not ids:
        # Nothing to look up; skip the database round-trip.
        return []
    try:
        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_text,
                        chunk_summary
                    FROM public.case_chunks
                    WHERE id = ANY(%s);
                """, [ids])
                rows = cur.fetchall()
        finally:
            # Release the connection even when execute/fetch raises.
            conn.close()
        position = {chunk_id: i for i, chunk_id in enumerate(ids)}
        rows.sort(key=lambda row: position.get(row[0], len(position)))
        return [
            {
                "id": row[0],
                "case_id": row[1],
                "chunk_index": row[2],
                "chunk_text": row[3],
                "chunk_summary": row[4],
            }
            for row in rows
        ]
    except Exception as e:
        st.error(f"Fetch error: {e}")
        return []
# Instruction that BGE expects in front of search *queries* (not passages).
_BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "


def get_embedding(text):
    """Embed a search query with the BGE model and return a plain list.

    The BGE query instruction is prepended before encoding; the encoder's
    array output is converted with ``tolist()`` so it can be passed directly
    as a database query parameter.
    """
    query_text = f"{_BGE_QUERY_PREFIX}{text}"
    vector = model.encode(query_text)
    return vector.tolist()
st.title("AI Legal Assistant ⚖️")
# Per-session chat transcript: {"role", "content"} message dicts, capped at
# the 10 most recent by the deque's maxlen so prompts built from it stay
# bounded. Created only once per session; reruns reuse the existing one.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Rewrite ``user_query`` into a standalone search query using chat history.

    Up to four messages of ``st.session_state.history`` that precede the
    current query are shown to the chat model as context. The chat UI
    appends the new user message to the history before calling this
    function, so a trailing copy of ``user_query`` is dropped first.
    Otherwise it would appear twice in the prompt and push one real prior
    message out of the four-message window.

    Returns:
        The model's rewrite, or ``user_query`` unchanged when the API call
        fails (the error is shown with ``st.error``) or the reply is empty.
    """
    hist = list(st.session_state.history)
    if hist and hist[-1] == {"role": "user", "content": user_query}:
        hist = hist[:-1]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist[-4:])
    messages = [
        {"role": "system", "content":
            "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewrite"},
        {"role": "user", "content":
            f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
            "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."},
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=400,
        )
        rewritten = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Rewrite error: {e}")
        rewritten = user_query
    # Never search with an empty string if the model returned nothing.
    return rewritten or user_query
# ── Full retrieval pipeline ────────────────────────────────
def retrieve_documents(query, top_k=10, candidate_multiplier=4):
    """Retrieve the ``top_k`` most relevant chunks for ``query``.

    The pipeline has three stages:
      1. vector search for ``top_k * candidate_multiplier`` chunk summaries
         (``retrieve_summaries``);
      2. the chat model picks the best ``top_k`` IDs from those summaries
         (``rerank_with_llm``);
      3. the full ``chunk_text`` is loaded only for the chosen IDs
         (``fetch_chunks_by_ids``).

    Args:
        query: the (rewritten) user query.
        top_k: number of chunks to return.
        candidate_multiplier: how many summaries per returned chunk to show
            the reranker (default 4).

    Returns:
        A list of chunk dicts. Returns ``[]`` if any stage yields nothing or
        raises; in the latter case the error is shown with ``st.error``.
    """
    try:
        # 1. Get candidate summaries
        candidates = retrieve_summaries(query, top_k=top_k * candidate_multiplier)
        if not candidates:
            return []
        # 2. LLM picks best IDs from summaries
        selected_ids = rerank_with_llm(query, candidates, final_k=top_k)
        if not selected_ids:
            return []
        # 3. Fetch full text for selected chunks only
        return fetch_chunks_by_ids(selected_ids)
    except Exception as e:
        # Last-resort guard: a failure in any stage degrades to "no context"
        # instead of crashing the Streamlit script.
        st.error(f"Retrieve error: {e}")
        return []
def clean_chunk_id(cid: str) -> str:
    """Turn a raw chunk identifier into a title-cased, human-readable string.

    Everything from the first ``_chunk`` onward is dropped, underscores and
    hyphens become spaces, runs of whitespace collapse to one space, and each
    word is capitalised (first letter upper-case, the rest lower-case).

    >>> clean_chunk_id("smith_v_jones_chunk_3")
    'Smith V Jones'
    """
    stem = re.sub(r'_chunk.*$', '', cid)
    words = stem.translate(str.maketrans("_-", "  ")).split()
    return " ".join(map(str.capitalize, words))
# ── Answer generation ──────────────────────────────────────
def _source_display_name(doc):
    """Return the human-readable label used to cite ``doc`` as a source.

    Constitution chunks (``case_id == "constitution"``) are labelled with
    their chunk number; every other chunk is labelled "Case Law:" followed
    by the first 30 words of its text.
    """
    if doc.get("case_id", "unknown") == "constitution":
        return f"Constitution (Chunk {doc.get('chunk_index', '')})"
    text_preview = " ".join((doc.get("chunk_text") or "").split()[:30])
    return f"Case Law: {text_preview}..."


def generate_response(user_query, docs):
    """Answer ``user_query`` from ``docs``, render it, and record it in history.

    Args:
        user_query: the (rewritten) user question.
        docs: chunk dicts as returned by ``fetch_chunks_by_ids``; this reads
            ``case_id``, ``chunk_index`` and ``chunk_text``.

    Side effects:
        Appends the assistant reply to ``st.session_state.history`` and
        renders the reply plus an expandable "Sources" section.

    Returns:
        The reply text. A "Source: ..." line listing the retrieved sources is
        appended when the model did not write one itself. If the API call
        fails, a fixed apology message is used instead.
    """
    # Label every chunk once so the context passages and the source list use
    # exactly the same names. A missing/None text is treated as "".
    labelled = [(_source_display_name(d), d.get("chunk_text") or "") for d in docs]

    # Tag each passage with its label so the model can tell which source a
    # passage came from when it writes the "Source:" line.
    context = "\n\n---\n\n".join(
        f"[{name}]\n{text}" for name, text in labelled if text
    )

    # label -> full text; later duplicates win, then sorted by label.
    source_links = dict(sorted(dict(labelled).items()))

    messages = [
        {"role": "system", "content":
            "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
            "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
            "Formatting rules:\n"
            "- For Constitution: show the chunk number.\n"
            "- For Case law: show first ~30 words of the case text.\n"
            "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
            "If multiple are used, separate them with commas."}
    ]
    # Prior turns, including the raw user message the chat UI just appended.
    messages.extend(list(st.session_state.history))
    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                     f"Sources:\n{', '.join(source_links.keys())}\n\n"
                     f"Question:\n{user_query}"})
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900,
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."

    # Guarantee a citation line even if the model omitted it.
    if source_links and "Source:" not in reply:
        reply += f"\n\nSource: {', '.join(source_links)}"

    st.session_state.history.append({"role": "assistant", "content": reply})
    st.markdown(reply)
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply
# ── Chat UI ────────────────────────────────────────────────
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    # Record the raw question first; get_rewritten_query and
    # generate_response both read the history.
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    generate_response(rewritten, docs)

# Transcript, newest first. Messages are appended as user -> assistant
# pairs, so in reverse each reply appears directly above its question; a
# rule after every user message separates one exchange from the next.
st.markdown("---")
for msg in reversed(st.session_state.history):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
        st.markdown("---")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")