Spaces:
Running
Running
| import os | |
| os.environ.setdefault("HF_HOME", "/tmp/huggingface") | |
| os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface") | |
| os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface") | |
| os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models") | |
| import streamlit as st | |
| import openai | |
| import psycopg2 | |
| from collections import deque | |
| from sentence_transformers import SentenceTransformer | |
| import re | |
# Setup: name of the chat model used for every LLM call, and the OpenAI client
# whose API key is read from the OPENAI_API_KEY environment variable.
ll_model = "gpt-4o-mini"
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# ── NEW: PostgreSQL connection ──────────────────────────────
def get_db_connection():
    """Open a new PostgreSQL connection from the RDS_* environment variables.

    Reads RDS_HOST, RDS_PORT (5432 when unset), RDS_DB, RDS_USER and RDS_PASS
    and returns whatever ``psycopg2.connect`` produces. Nothing here closes
    the connection; that is left to the caller.
    """
    env = os.environ
    settings = {
        "host": env.get("RDS_HOST"),
        "port": env.get("RDS_PORT", 5432),
        "dbname": env.get("RDS_DB"),
        "user": env.get("RDS_USER"),
        "password": env.get("RDS_PASS"),
    }
    return psycopg2.connect(**settings)
# ── NEW: BGE model ──────────────────────────────────────────
@st.cache_resource
def _load_embedding_model():
    """Load the BGE sentence-embedding model.

    Streamlit re-executes this script on every user interaction; caching the
    loaded model as a resource avoids re-instantiating it on each rerun.
    """
    return SentenceTransformer('BAAI/bge-small-en-v1.5')


# Module-level handle used by get_embedding(); same name as before.
model = _load_embedding_model()
def retrieve_summaries(query, top_k=40):
    """Return the ``top_k`` chunk summaries closest to ``query``.

    The query is embedded with ``get_embedding`` and compared against
    ``public.case_chunks.embedding`` with pgvector's ``<=>`` operator.

    Args:
        query: Search text. A blank or empty query returns ``[]`` without
            touching the model or the database.
        top_k: Maximum number of rows to return.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_summary`` and ``similarity`` (``1 - distance``), ordered by
        increasing distance. On any error the message is shown with
        ``st.error`` and ``[]`` is returned.
    """
    if not query or not query.strip():
        return []

    columns = ("id", "case_id", "chunk_index", "chunk_summary", "similarity")
    try:
        embedding = get_embedding(query)
        conn = get_db_connection()
        try:
            # The cursor context manager closes the cursor; the finally
            # clause closes the connection even when execute/fetch raises.
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_summary,
                        1 - (embedding <=> %s::vector) AS similarity
                    FROM public.case_chunks
                    ORDER BY embedding <=> %s::vector
                    LIMIT %s;
                """, [embedding, embedding, top_k])
                rows = cur.fetchall()
        finally:
            conn.close()
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []

    return [dict(zip(columns, row)) for row in rows]
# ── STEP 2: LLM picks best chunks based on summaries ───────
def rerank_with_llm(query, candidates, final_k=10):
    """Ask the chat model to pick the most relevant candidate chunk IDs.

    Args:
        query: The user's search query.
        candidates: Dicts with at least ``id``, ``case_id`` and
            ``chunk_summary``, in the order produced by the vector search.
        final_k: Maximum number of IDs to return.

    Returns:
        Up to ``final_k`` candidate IDs in the order the model listed them.
        IDs not present in ``candidates`` and repeated IDs are discarded. If
        the call fails or nothing usable can be parsed from the reply, the
        first ``final_k`` candidate IDs are returned instead. An empty
        ``candidates`` list returns ``[]`` without calling the model.
    """
    if not candidates:
        return []

    # Fallback: the leading candidates in the order they were supplied.
    fallback_ids = [c["id"] for c in candidates[:final_k]]

    summary_list = "\n".join([
        f"[ID: {c['id']}] Case: {c['case_id']} | Summary: {c['chunk_summary']}"
        for c in candidates
    ])
    messages = [
        {"role": "system", "content":
            "You are a legal research assistant. Given a user query and a list of document chunk summaries, "
            "select the most relevant chunk IDs that would best answer the query. "
            "Return ONLY a comma-separated list of IDs, nothing else. Example: 12,45,67,23"
        },
        {"role": "user", "content":
            f"Query: {query}\n\n"
            f"Chunks:\n{summary_list}\n\n"
            f"Select the {final_k} most relevant chunk IDs."
        }
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.0,
            max_tokens=200
        )
        raw = (resp.choices[0].message.content or "").strip()
    except Exception as e:
        st.error(f"Rerank error: {e}")
        return fallback_ids

    # Pull out every run of digits so replies such as "[12, 45]" or "12, 45."
    # still parse, then keep only IDs that were actually offered.
    known_ids = {str(c["id"]): c["id"] for c in candidates}
    selected_ids = []
    seen = set()
    for token in re.findall(r"\d+", raw):
        key = str(int(token))  # normalise leading zeros
        if key in known_ids and key not in seen:
            seen.add(key)
            selected_ids.append(known_ids[key])

    return selected_ids[:final_k] or fallback_ids
# ── STEP 3: fetch full chunk_text for selected IDs only ────
def fetch_chunks_by_ids(selected_ids):
    """Fetch full chunk rows from ``public.case_chunks`` for the given IDs.

    Args:
        selected_ids: Chunk IDs, most relevant first. ``None`` or an empty
            sequence returns ``[]`` without opening a connection.

    Returns:
        A list of dicts with keys ``id``, ``case_id``, ``chunk_index``,
        ``chunk_text`` and ``chunk_summary``, arranged in the same order as
        ``selected_ids``. On any error the message is shown with ``st.error``
        and ``[]`` is returned.
    """
    selected_ids = list(selected_ids or [])
    if not selected_ids:
        return []

    columns = ("id", "case_id", "chunk_index", "chunk_text", "chunk_summary")
    try:
        conn = get_db_connection()
        try:
            # The finally clause closes the connection even when the query fails.
            with conn.cursor() as cur:
                cur.execute("""
                    SELECT
                        id,
                        case_id,
                        chunk_index,
                        chunk_text,
                        chunk_summary
                    FROM public.case_chunks
                    WHERE id = ANY(%s);
                """, [selected_ids])
                rows = cur.fetchall()
        finally:
            conn.close()
    except Exception as e:
        st.error(f"Fetch error: {e}")
        return []

    # The query has no ORDER BY, so restore the caller's ranking explicitly.
    rank = {}
    for position, chunk_id in enumerate(selected_ids):
        rank.setdefault(chunk_id, position)
    ordered = sorted(rows, key=lambda row: rank.get(row[0], len(rank)))

    return [dict(zip(columns, row)) for row in ordered]
def get_embedding(text):
    """Encode ``text`` as a search query and return ``encode(...).tolist()``."""
    # Per the original author's note, BGE wants this instruction prepended to queries.
    instruction = "Represent this sentence for searching relevant passages: "
    vector = model.encode(f"{instruction}{text}")
    return vector.tolist()
# Page title. The trailing scales emoji had been mis-decoded into Greek
# letters; restored to the intended character.
st.title("AI Legal Assistant ⚖️")

# Conversation memory kept across Streamlit reruns; the deque discards the
# oldest entries once more than 10 messages are stored.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Rewrite ``user_query`` into a context-aware search query via the LLM.

    Up to four earlier messages from ``st.session_state.history`` are given
    to the model as context.

    Args:
        user_query: The text the user just entered.

    Returns:
        The model's rewritten query. ``user_query`` is returned unchanged when
        it is blank, when the call fails (the error is shown with
        ``st.error``), or when the model returns an empty reply.
    """
    if not user_query or not user_query.strip():
        return user_query

    hist = list(st.session_state.history)
    # If the newest history entry is this very query, drop it so the prompt
    # does not list the same text as both history and new query.
    if hist and hist[-1].get("role") == "user" and hist[-1].get("content") == user_query:
        hist = hist[:-1]
    hist = hist[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)

    messages = [
        {"role": "system", "content":
            "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewrite"},
        {"role": "user", "content":
            f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
            "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."}
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=400
        )
        rewritten = (resp.choices[0].message.content or "").strip()
    except Exception as e:
        st.error(f"Rewrite error: {e}")
        return user_query

    # An empty rewrite would produce a meaningless search; keep the original.
    return rewritten or user_query
# ── UPDATED: retrieve from pgvector ────────────────────────
| # def retrieve_documents(query, top_k=10): | |
| # try: | |
| # embedding = get_embedding(query) | |
| # conn = get_db_connection() | |
| # cur = conn.cursor() | |
| # cur.execute(""" | |
| # SELECT | |
| # case_id, | |
| # chunk_index, | |
| # chunk_text, | |
| # chunk_summary, | |
| # 1 - (embedding <=> %s::vector) AS similarity | |
| # FROM public.case_chunks | |
| # ORDER BY embedding <=> %s::vector | |
| # LIMIT %s; | |
| # """, [embedding, embedding, top_k]) | |
| # rows = cur.fetchall() | |
| # cur.close() | |
| # conn.close() | |
| # # Format to match the rest of the app | |
| # docs = [] | |
| # for row in rows: | |
| # docs.append({ | |
| # "case_id": row[0], | |
| # "chunk_index": row[1], | |
| # "chunk_text": row[2], | |
| # "chunk_summary": row[3], | |
| # "similarity": row[4] | |
| # }) | |
| # return docs | |
# ── COMBINED: full retrieval pipeline ──────────────────────
def retrieve_documents(query, top_k=10):
    """Run the three-step retrieval pipeline and return full chunk rows.

    Steps: (1) fetch ``top_k * 4`` candidate summaries by vector search,
    (2) let the LLM choose up to ``top_k`` IDs from those summaries,
    (3) fetch the full rows for the chosen IDs.

    Args:
        query: Search text. A blank query returns ``[]`` immediately.
        top_k: Number of chunks wanted in the final result.

    Returns:
        Whatever ``fetch_chunks_by_ids`` returns for the selected IDs, or
        ``[]`` when a step yields nothing or an unexpected error occurs (the
        error is shown with ``st.error``).
    """
    if not query or not query.strip():
        return []

    # The original had an `except` clause with no matching `try`; the whole
    # pipeline is now guarded so that handler is reachable.
    try:
        # 1. Get 4x summaries
        candidates = retrieve_summaries(query, top_k=top_k * 4)
        if not candidates:
            return []

        # 2. LLM picks best IDs from summaries
        selected_ids = rerank_with_llm(query, candidates, final_k=top_k)
        if not selected_ids:
            return []

        # 3. Fetch full text for selected chunks only
        return fetch_chunks_by_ids(selected_ids)
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []
def clean_chunk_id(cid: str) -> str:
    """Turn a raw chunk identifier into a display title.

    Everything from the first ``_chunk`` onward is removed, underscores and
    hyphens become spaces, and each remaining word is capitalised (first
    letter upper-case, the rest lower-case).
    """
    stem = re.sub(r'_chunk.*$', '', cid)
    spaced = stem.translate(str.maketrans("_-", "  "))
    return " ".join(map(str.capitalize, spaced.split()))
# ── UPDATED: generate response with new doc structure ──────
def _build_source_links(docs):
    """Map a display label for each retrieved chunk to that chunk's full text."""
    links = {}
    for d in docs:
        chunk_text = d.get("chunk_text") or ""
        if d.get("case_id", "unknown") == "constitution":
            label = f"Constitution (Chunk {d.get('chunk_index', '')})"
        else:
            preview = " ".join(chunk_text.split()[:30])
            label = f"Case Law: {preview}..."
        links[label] = chunk_text
    # Sorted by label so the listing order is stable.
    return dict(sorted(links.items()))


def generate_response(user_query, docs):
    """Answer ``user_query`` from the retrieved ``docs`` and render the result.

    Builds a context string from each doc's ``chunk_text``, sends it to the
    chat model together with the stored conversation history, appends the
    reply to ``st.session_state.history``, and renders the reply plus one
    expander per source.

    Args:
        user_query: The (possibly rewritten) question to answer.
        docs: Dicts that may carry ``case_id``, ``chunk_index`` and
            ``chunk_text``; missing or empty ``chunk_text`` is skipped.

    Returns:
        The reply text. If the model call fails, the error is shown with
        ``st.error`` and a fixed apology string is returned, with no source
        footer appended.
    """
    # Collect context from chunk_text (tolerating docs without that key)
    context = "\n\n---\n\n".join(
        d["chunk_text"] for d in docs if d.get("chunk_text")
    )
    source_links = _build_source_links(docs)

    messages = [
        {"role": "system", "content":
            "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
            "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
            "Formatting rules:\n"
            "- For Constitution: show the chunk number.\n"
            "- For Case law: show first ~30 words of the case text.\n"
            "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
            "If multiple are used, separate them with commas."}
    ]
    messages.extend(st.session_state.history)
    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                     f"Sources:\n{', '.join(source_links.keys())}\n\n"
                     f"Question:\n{user_query}"})

    succeeded = True
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900
        )
        reply = (resp.choices[0].message.content or "").strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."
        succeeded = False

    # Only a real answer gets a source footer; the apology text cites nothing.
    if succeeded and source_links and "Source:" not in reply:
        clean_sources = ", ".join(source_links.keys())
        reply += f"\n\nSource: {clean_sources}"

    st.session_state.history.append({"role": "assistant", "content": reply})
    st.markdown(reply)

    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)
    return reply
# Chat UI: a single-line input form followed by the stored conversation.
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

# Ignore submissions that contain only whitespace.
if submit and user_input.strip():
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    assistant_reply = generate_response(rewritten, docs)

# Conversation log, newest message first.
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history), start=1):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")
    # The original tested `c ^ 1` (bitwise XOR), which is falsy only for the
    # first message and so drew a rule after every later message. A rule
    # after every second message keeps each question/answer pair together.
    # NOTE(review): pair grouping is the inferred intent — confirm.
    if position % 2 == 0:
        st.markdown("---")