import os

# Hugging Face caches must live in a writable directory (e.g. /tmp on hosted
# Streamlit) — set before any HF-backed library is imported.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")

import re
from collections import deque

import openai
import streamlit as st
from pinecone import Pinecone
from sentence_transformers import SentenceTransformer

# --- Clients & models (API keys come from environment variables) ---
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-ai")
model = SentenceTransformer('all-mpnet-base-v2')

ll_model = 'gpt-4o-mini'

st.title("AI Legal Assistant ⚖️")

# Conversation memory: last 5 user/assistant pairs = 10 messages.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)


def get_rewritten_query(user_query):
    """Rewrite *user_query* into a standalone, search-friendly query.

    Sends the last few turns of conversation plus the new query to the LLM so
    that follow-up questions ("what about article 8?") become self-contained
    queries suitable for vector-DB lookup. Falls back to the original query on
    any API error (the error is surfaced in the UI via ``st.error``).
    """
    hist = list(st.session_state.history)[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)
    messages = [
        {"role": "system",
         "content": "You are a legal assistant that rewrites user queries into "
                    "clear, context-aware queries for vector DB lookup. "
                    "If it's already clear then don't rewrite"},
        {"role": "user",
         "content": f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
                    "Rewrite if needed for clarity/search purposes. "
                    "Otherwise, repeat exactly."},
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=400,
        )
        rewritten = resp.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: a failed rewrite should not block the pipeline.
        st.error(f"Rewrite error: {e}")
        rewritten = user_query
    return rewritten


def retrieve_documents(query, top_k=10):
    """Embed *query* and return the ``top_k`` nearest Pinecone matches.

    Returns an empty list on any retrieval error (surfaced via ``st.error``),
    so downstream code can always iterate the result.
    """
    emb = model.encode(query).tolist()
    try:
        return index.query(vector=emb, top_k=top_k, include_metadata=True)['matches']
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []


def clean_chunk_id(cid: str) -> str:
    """Beautify a chunk_id: strip the '_chunk_xxx' suffix, turn underscores
    and dashes into spaces, and capitalize each word."""
    # Remove any trailing '_chunk_xxx' suffix.
    cid = re.sub(r'_chunk.*$', '', cid)
    # Replace _ and - with spaces.
    cid = cid.replace("_", " ").replace("-", " ")
    # Capitalize each word.
    return " ".join(word.capitalize() for word in cid.split())


def _build_source_links(docs):
    """Map a human-friendly display name -> full chunk text for each match.

    Constitution/ordinance chunks show their cleaned chunk id; case law shows
    the first ~30 words of the text instead. Duplicate display names collapse
    into one entry (dict keys), and the result is sorted by name.
    """
    source_links = {}
    for d in docs:
        meta = d['metadata']
        src = meta.get("source", "unknown").lower()
        cid = meta.get("chunk_id", "")
        text_preview = " ".join(meta.get("text", "").split()[:30])
        if src in ["constitution"]:
            display_name = f"Constitution ({clean_chunk_id(cid)})"
        elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
            display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"
        elif src in ["case_law", "case", "tax_case"]:
            display_name = f"Case Law: {text_preview}..."
        else:
            display_name = f"{src.title()} ({clean_chunk_id(cid)})"
        source_links[display_name] = meta.get("text", "")
    return dict(sorted(source_links.items()))


def generate_response(user_query, docs):
    """Answer *user_query* from the retrieved *docs* and render the result.

    Builds a context block from the match texts, asks the LLM for an answer
    ending in a 'Source: ...' line, appends the reply to session history, and
    renders the reply plus expandable source panels in Streamlit. Returns the
    reply text (an apology string if the API call failed).
    """
    # --- Collect context ---
    context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)

    # --- Build human-friendly sources + mapping (deduplicated, sorted) ---
    source_links = _build_source_links(docs)

    # --- System prompt ---
    messages = [
        {"role": "system",
         "content": "You are a helpful legal assistant. "
                    "Use the provided context from documents to answer the user's question. "
                    "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
                    "Formatting rules:\n"
                    "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
                    "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
                    "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
                    "If multiple are used, separate them with commas."},
    ]
    messages.extend(st.session_state.history)
    messages.append({"role": "user",
                     "content": f"Context:\n{context}\n\n"
                                f"Sources:\n{', '.join(source_links.keys())}\n\n"
                                f"Question:\n{user_query}"})

    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900,
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."

    # Force a clean source line if the LLM omitted it.
    if source_links:
        clean_sources = ", ".join(source_links.keys())
        if "Source:" not in reply:
            reply += f"\n\nSource: {clean_sources}"

    # Save reply into history.
    st.session_state.history.append({"role": "assistant", "content": reply})

    # --- Render in Streamlit ---
    st.markdown(reply)

    # Expandable panels with the full text of each cited source.
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)

    return reply


# --- Chat UI ---
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    assistant_reply = generate_response(rewritten, docs)

# --- Display history, most recent first ---
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history), start=1):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")
    # Separator after every message except the most recent one
    # (preserves the original `c ^ 1` counter logic).
    if position != 1:
        st.markdown("---")