# src/streamlit_app.py — AI Legal Assistant (Streamlit RAG app)
# (last update by omarkashif, commit f6d9e3b)
# Redirect all HuggingFace cache locations to /tmp so the app can run in a
# container (e.g. HF Spaces) where the default home directory is not writable.
# Must happen before any transformers/sentence-transformers import.
import os
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")
import streamlit as st
import openai
from collections import deque
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import re
# Clients are configured from environment variables (no hardcoded keys).
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-ai")  # Pinecone index holding the legal-document embeddings
model = SentenceTransformer('all-mpnet-base-v2')  # embedding model for retrieval
chat_history = deque(maxlen=10) # last 5 pairs = 10 messages; NOTE(review): appears unused — st.session_state.history is what the app reads/writes
ll_model = 'gpt-4o-mini'  # chat model used for both query rewriting and answering
st.title("AI Legal Assistant โš–๏ธ")
# Persist the last 10 messages (5 user/assistant pairs) across Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Rewrite *user_query* into a standalone, search-ready query.

    Uses the last few turns of the session history as context so that
    follow-up questions ("what about article 9?") become self-contained
    queries suitable for vector-DB lookup. Falls back to the original
    query if the OpenAI call fails.
    """
    recent_turns = list(st.session_state.history)[-4:]
    formatted_history = "\n".join(
        f"{turn['role']}: {turn['content']}" for turn in recent_turns
    )
    prompt = [
        {
            "role": "system",
            "content": (
                "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewite"
            ),
        },
        {
            "role": "user",
            "content": (
                f"History:\n{formatted_history}\n\nNew query:\n{user_query}\n\n"
                "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."
            ),
        },
    ]
    try:
        response = client.chat.completions.create(
            model=ll_model,
            messages=prompt,
            temperature=0.1,
            max_tokens=400,
        )
        return response.choices[0].message.content.strip()
    except Exception as exc:
        # Best-effort: surface the error in the UI and search with the raw query.
        st.error(f"Rewrite error: {exc}")
        return user_query
def retrieve_documents(query, top_k=10):
    """Embed *query* and return the top-*top_k* matches from the Pinecone index.

    Returns the raw list of match objects (with metadata); on failure shows
    a Streamlit error and returns an empty list so callers can proceed.
    """
    query_vector = model.encode(query).tolist()
    try:
        result = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
        return result['matches']
    except Exception as exc:
        st.error(f"Retrieve error: {exc}")
        return []
def clean_chunk_id(cid: str) -> str:
    """Beautify chunk_id by replacing underscores/dashes with spaces and capitalizing words."""
    # Drop any trailing "_chunk..." suffix, normalise the remaining
    # separators to spaces, then title-case each word.
    base = re.sub(r'_chunk.*$', '', cid)
    words = base.replace("_", " ").replace("-", " ").split()
    return " ".join(word.capitalize() for word in words)
def _source_display_name(meta):
    """Map one match's metadata dict to a human-friendly source label.

    Constitution/ordinance chunks show a cleaned chunk id; case law shows a
    ~30-word preview of the text instead (chunk ids are meaningless there).
    """
    src = meta.get("source", "unknown").lower()
    cid = meta.get("chunk_id", "")
    text_preview = " ".join(meta.get("text", "").split()[:30])
    if src in ["constitution"]:
        return f"Constitution ({clean_chunk_id(cid)})"
    if src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
        return f"Tax Ordinance ({clean_chunk_id(cid)})"
    if src in ["case_law", "case", "tax_case"]:
        return f"Case Law: {text_preview}..."
    return f"{src.title()} ({clean_chunk_id(cid)})"

def generate_response(user_query, docs):
    """Answer *user_query* with the LLM, grounded in retrieved *docs*.

    Side effects: appends the assistant reply to st.session_state.history
    and renders the reply plus expandable source excerpts in the UI.
    Returns the reply text.
    """
    # --- Collect context; use .get consistently (fix: the rest of this
    # function treats a missing 'text' field as "", so do the same here
    # instead of raising KeyError) ---
    context = "\n\n---\n\n".join(d['metadata'].get('text', '') for d in docs)
    # --- Build human-friendly sources mapped to their full excerpt text.
    # Keying a dict by display name also deduplicates repeated sources. ---
    source_links = {}
    for d in docs:
        meta = d['metadata']
        source_links[_source_display_name(meta)] = meta.get("text", "")
    # Sort for a stable, readable source list.
    source_links = dict(sorted(source_links.items()))
    # --- System prompt ---
    messages = [
        {"role": "system", "content":
            "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
            "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
            "Formatting rules:\n"
            "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
            "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
            "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
            "If multiple are used, separate them with commas."}
    ]
    # Include prior turns so the model keeps conversational context.
    messages.extend(st.session_state.history)
    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                     f"Sources:\n{', '.join(source_links.keys())}\n\n"
                     f"Question:\n{user_query}"})
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."
    # Force a clean source line if the LLM omitted it.
    if source_links and "Source:" not in reply:
        reply += f"\n\nSource: {', '.join(source_links.keys())}"
    # Save reply into history so follow-up questions have context.
    st.session_state.history.append({"role": "assistant", "content": reply})
    # --- Render in Streamlit ---
    st.markdown(reply)
    # Expandable full-text excerpts for each cited source.
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply
# --- Chat UI ---
# clear_on_submit empties the text box after each send.
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    # Record the user's turn, rewrite it for retrieval, fetch supporting
    # documents, then generate (and render) the assistant's answer.
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    generate_response(rewritten, docs)

# --- Display history, newest message first ---
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history)):
    # Fix: the separator belongs *between* messages. The old `c ^ 1` test
    # skipped it after the first message only, leaving the two newest
    # messages unseparated and a stray rule after the oldest one.
    if position:
        st.markdown("---")
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")