Spaces:
Sleeping
Sleeping
import os

# Redirect all Hugging Face / sentence-transformers caches into /tmp so the
# app can run on read-only filesystems (e.g. Hugging Face Spaces containers).
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")  # legacy alias of HF_HOME
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")

import streamlit as st
import openai
from collections import deque
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import re

# Clients are configured from environment variables (no keys are hardcoded).
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-ai")  # Pinecone index holding the legal-document embeddings
model = SentenceTransformer('all-mpnet-base-v2')  # query/document sentence encoder

# NOTE(review): this module-level deque appears unused within this file —
# st.session_state.history is the actual conversation store. Kept only in
# case an external module imports it; candidate for removal.
chat_history = deque(maxlen=10)  # last 5 user/assistant pairs = 10 messages

ll_model = 'gpt-4o-mini'  # chat model used for both query rewriting and answering

st.title("AI Legal Assistant ⚖️")

# Per-session rolling chat history; Streamlit reruns this script on every event,
# so the deque must live in session_state to survive reruns.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Rewrite a follow-up user query into a self-contained retrieval query.

    Feeds the last few turns of the session history to the chat model so that
    pronouns/ellipses in the new query get resolved before the vector-DB
    lookup. Falls back to the raw query if the OpenAI call fails.

    Args:
        user_query: The raw text the user just submitted.

    Returns:
        The rewritten (or unchanged) query string.
    """
    # The most recent 4 messages (2 user/assistant pairs) are enough context
    # for disambiguation while keeping the prompt small.
    recent = list(st.session_state.history)[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in recent)
    messages = [
        {"role": "system", "content":
            "You are a legal assistant that rewrites user queries into clear, "
            "context-aware queries for vector DB lookup. If it's already clear "
            "then don't rewrite."},
        {"role": "user", "content":
            f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
            "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."},
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,  # near-deterministic rewriting
            max_tokens=400,
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: surface the error in the UI but keep the app usable
        # by searching with the original query.
        st.error(f"Rewrite error: {e}")
        return user_query
def retrieve_documents(query, top_k=10):
    """Embed *query* and fetch the nearest matches from the Pinecone index.

    Args:
        query: Search text to embed with the sentence-transformer model.
        top_k: Maximum number of matches to return.

    Returns:
        The list of Pinecone match objects (with metadata), or an empty list
        if the index query fails — callers never need to handle exceptions.
    """
    query_vector = model.encode(query).tolist()
    try:
        response = index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []
    return response['matches']
def clean_chunk_id(cid: str) -> str:
    """Beautify chunk_id by replacing underscores/dashes with spaces and capitalizing words."""
    # Drop any trailing '_chunk_xxx' suffix first.
    base = re.sub(r'_chunk.*$', '', cid)
    # Map both separator characters to spaces in a single pass, then let
    # str.split() collapse runs of whitespace before title-casing each word.
    separators = str.maketrans("_-", "  ")
    return " ".join(word.capitalize() for word in base.translate(separators).split())
def _source_display_name(meta):
    """Build a human-friendly source label from one retrieved doc's metadata."""
    src = meta.get("source", "unknown").lower()
    cid = meta.get("chunk_id", "")
    if src == "constitution":
        return f"Constitution ({clean_chunk_id(cid)})"
    if src in ("fbr_ordinance", "ordinance", "tax_ordinance"):
        return f"Tax Ordinance ({clean_chunk_id(cid)})"
    if src in ("case_law", "case", "tax_case"):
        # Case law has no meaningful chunk id; preview the first ~30 words instead.
        preview = " ".join(meta.get("text", "").split()[:30])
        return f"Case Law: {preview}..."
    return f"{src.title()} ({clean_chunk_id(cid)})"


def generate_response(user_query, docs):
    """Answer *user_query* from retrieved *docs*, render the answer, and return it.

    Builds a RAG prompt from the docs' text, calls the chat model, appends the
    reply to session history, and renders the reply plus expandable source
    panels in Streamlit.

    Args:
        user_query: The (possibly rewritten) question to answer.
        docs: Pinecone match objects, each with a 'metadata' dict.

    Returns:
        The assistant's reply text (or an error placeholder string).
    """
    # --- Collect context (use .get so a doc missing 'text' cannot crash) ---
    context = "\n\n---\n\n".join(d['metadata'].get('text', '') for d in docs)

    # --- Human-friendly label -> full source text; the dict dedupes labels ---
    source_links = {
        _source_display_name(d['metadata']): d['metadata'].get("text", "")
        for d in docs
    }
    source_links = dict(sorted(source_links.items()))  # stable alphabetical order

    # --- System prompt (source-formatting rules mirror _source_display_name) ---
    messages = [
        {"role": "system", "content":
            "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
            "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
            "Formatting rules:\n"
            "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
            "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
            "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
            "If multiple are used, separate them with commas."}
    ]
    messages.extend(st.session_state.history)
    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                     f"Sources:\n{', '.join(source_links.keys())}\n\n"
                     f"Question:\n{user_query}"})

    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900,
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."

    # Force a clean source line if the LLM omitted it.
    if source_links and "Source:" not in reply:
        reply += f"\n\nSource: {', '.join(source_links.keys())}"

    # Persist the reply so later turns (and the history panel) can see it.
    st.session_state.history.append({"role": "assistant", "content": reply})

    # --- Render in Streamlit ---
    st.markdown(reply)
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply
# Chat UI: form clears its text box on submit; Streamlit reruns the whole
# script on every interaction, so the handler below runs once per send.
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    assistant_reply = generate_response(rewritten, docs)

# Display history, newest first, with a divider *between* messages only
# (the original `c ^ 1` bitwise test also drew a stray divider after the
# final message).
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history)):
    if position:
        st.markdown("---")
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")