# Spaces: Sleeping (Hugging Face Space status — page-scrape residue, kept as a comment)
# File size: 6,201 Bytes
import os

# Cache directories must live under /tmp on hosted Spaces (the app dir is
# read-only). These MUST be set before importing sentence_transformers /
# transformers, which read them at import time.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")

import streamlit as st
import openai
from collections import deque
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import re

# API keys are read from the environment — never hardcode credentials.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-ai")
model = SentenceTransformer('all-mpnet-base-v2')

chat_history = deque(maxlen=10)  # NOTE(review): unused — st.session_state.history is the real store
ll_model = 'gpt-4o-mini'

st.title("AI Legal Assistant ⚖️")

# Persist the conversation across Streamlit reruns (last 5 user/assistant pairs).
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Rewrite *user_query* into a standalone, search-friendly query.

    Uses the last few turns of the conversation so follow-up questions
    (e.g. "what about section 7?") become self-contained before the
    vector-DB lookup. Falls back to the raw query on any API error.
    """
    recent = list(st.session_state.history)[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in recent)
    # Fixed typos in the original instruction ("its", "dont rewite") —
    # the prompt is the model's contract, so it should read cleanly.
    messages = [
        {"role": "system", "content":
            "You are a legal assistant that rewrites user queries into clear, "
            "context-aware queries for vector DB lookup. If it's already clear "
            "then don't rewrite."},
        {"role": "user", "content":
            f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
            "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."}
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=400,
        )
        rewritten = resp.choices[0].message.content.strip()
    except Exception as e:  # surface the failure but keep the app usable
        st.error(f"Rewrite error: {e}")
        rewritten = user_query
    return rewritten
def retrieve_documents(query, top_k=10):
    """Embed *query* and return the ``top_k`` nearest matches from Pinecone.

    Returns an empty list (after showing a Streamlit error) if the
    index query fails for any reason.
    """
    vector = model.encode(query).tolist()
    try:
        result = index.query(vector=vector, top_k=top_k, include_metadata=True)
        return result['matches']
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []
def clean_chunk_id(cid: str) -> str:
    """Turn a raw chunk id into a human-readable title.

    Drops any trailing ``_chunk...`` suffix, converts underscores and
    dashes to spaces, and capitalizes each remaining word.
    """
    base = re.sub(r'_chunk.*$', '', cid)
    spaced = base.replace("_", " ").replace("-", " ")
    return " ".join(word.capitalize() for word in spaced.split())
def generate_response(user_query, docs):
    """Answer *user_query* from retrieved *docs*, render it, and return it.

    Builds the LLM context from the document texts, asks the chat model to
    answer with a human-friendly "Source:" line, appends the reply to the
    session history, and renders the reply plus expandable source texts.
    """
    # --- Collect context (use .get so a chunk missing 'text' metadata
    #     degrades gracefully instead of raising KeyError, consistent with
    #     the .get(...) accesses below) ---
    context = "\n\n---\n\n".join(d['metadata'].get('text', '') for d in docs)

    # --- Build human-friendly source names mapped to their full text ---
    source_links = {}
    for d in docs:
        meta = d['metadata']
        src = meta.get("source", "unknown").lower()
        cid = meta.get("chunk_id", "")
        text_preview = " ".join(meta.get("text", "").split()[:30])
        if src in ["constitution"]:
            display_name = f"Constitution ({clean_chunk_id(cid)})"
        elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
            display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"
        elif src in ["case_law", "case", "tax_case"]:
            # Case law chunk ids are not meaningful; preview the text instead.
            display_name = f"Case Law: {text_preview}..."
        else:
            display_name = f"{src.title()} ({clean_chunk_id(cid)})"
        source_links[display_name] = meta.get("text", "")
    # Dict keys are already unique; sorting gives a stable display order.
    source_links = dict(sorted(source_links.items()))

    # --- System prompt (source-formatting rules for the model) ---
    messages = [
        {"role": "system", "content":
            "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
            "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
            "Formatting rules:\n"
            "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
            "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
            "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
            "If multiple are used, separate them with commas."}
    ]
    messages.extend(st.session_state.history)
    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                     f"Sources:\n{', '.join(source_links.keys())}\n\n"
                     f"Question:\n{user_query}"})

    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."

    # Fallback: append a clean source line if the model omitted it.
    if source_links:
        clean_sources = ", ".join(source_links.keys())
        if "Source:" not in reply:
            reply += f"\n\nSource: {clean_sources}"

    # Save reply into history so future rewrites/answers see it.
    st.session_state.history.append({"role": "assistant", "content": reply})

    # --- Render in Streamlit ---
    st.markdown(reply)
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply
# Chat UI
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    assistant_reply = generate_response(rewritten, docs)

# Display history newest-first, with a divider after every entry except the first.
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history), start=1):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")
    if position != 1:
        st.markdown("---")