Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files
- src/streamlit_app.py +32 -17
src/streamlit_app.py
CHANGED
|
@@ -18,6 +18,7 @@ pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
|
|
| 18 |
index = pc.Index("legal-ai")
|
| 19 |
model = SentenceTransformer('all-mpnet-base-v2')
|
| 20 |
chat_history = deque(maxlen=10) # last 5 pairs = 10 messages
|
|
|
|
| 21 |
|
| 22 |
st.title("AI Legal Assistant ⚖️")
|
| 23 |
|
|
@@ -36,7 +37,7 @@ def get_rewritten_query(user_query):
|
|
| 36 |
]
|
| 37 |
try:
|
| 38 |
resp = client.chat.completions.create(
|
| 39 |
-
model=
|
| 40 |
messages=messages,
|
| 41 |
temperature=0.1,
|
| 42 |
max_tokens=400
|
|
@@ -70,10 +71,11 @@ def clean_chunk_id(cid: str) -> str:
|
|
| 70 |
|
| 71 |
|
| 72 |
def generate_response(user_query, docs):
|
|
|
|
| 73 |
context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)
|
| 74 |
-
|
| 75 |
-
# --- Build human-friendly sources ---
|
| 76 |
-
|
| 77 |
for d in docs:
|
| 78 |
meta = d['metadata']
|
| 79 |
src = meta.get("source", "unknown").lower()
|
|
@@ -81,20 +83,21 @@ def generate_response(user_query, docs):
|
|
| 81 |
text_preview = " ".join(meta.get("text", "").split()[:30])
|
| 82 |
|
| 83 |
if src in ["constitution"]:
|
| 84 |
-
|
| 85 |
|
| 86 |
elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
|
| 87 |
-
|
| 88 |
|
| 89 |
elif src in ["case_law", "case", "tax_case"]:
|
| 90 |
-
|
| 91 |
-
readable_sources.append(f"Case Law: {text_preview}...")
|
| 92 |
|
| 93 |
else:
|
| 94 |
-
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
# Deduplicate
|
| 97 |
-
|
| 98 |
|
| 99 |
# --- System prompt ---
|
| 100 |
messages = [
|
|
@@ -108,15 +111,14 @@ def generate_response(user_query, docs):
|
|
| 108 |
"If multiple are used, separate them with commas."}
|
| 109 |
]
|
| 110 |
|
| 111 |
-
|
| 112 |
messages.extend(st.session_state.history)
|
| 113 |
-
|
| 114 |
messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
|
| 115 |
-
f"Sources:\n{', '.join(
|
| 116 |
f"Question:\n{user_query}"})
|
| 117 |
try:
|
| 118 |
resp = client.chat.completions.create(
|
| 119 |
-
model=
|
| 120 |
messages=messages,
|
| 121 |
temperature=0.1,
|
| 122 |
max_tokens=900
|
|
@@ -127,16 +129,29 @@ def generate_response(user_query, docs):
|
|
| 127 |
reply = "Sorry, I encountered an error generating the answer."
|
| 128 |
|
| 129 |
# Optional: force clean source line if LLM misses it
|
| 130 |
-
if
|
| 131 |
-
clean_sources = ", ".join(
|
| 132 |
if "Source:" not in reply:
|
| 133 |
reply += f"\n\nSource: {clean_sources}"
|
| 134 |
|
|
|
|
| 135 |
st.session_state.history.append({"role": "assistant", "content": reply})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return reply
|
| 137 |
|
| 138 |
|
| 139 |
|
|
|
|
| 140 |
# Chat UI
|
| 141 |
with st.form("chat_input", clear_on_submit=True):
|
| 142 |
user_input = st.text_input("You:", "")
|
|
|
|
| 18 |
index = pc.Index("legal-ai")
|
| 19 |
model = SentenceTransformer('all-mpnet-base-v2')
|
| 20 |
chat_history = deque(maxlen=10) # last 5 pairs = 10 messages
|
| 21 |
+
ll_model = 'gpt-4o-mini'
|
| 22 |
|
| 23 |
st.title("AI Legal Assistant ⚖️")
|
| 24 |
|
|
|
|
| 37 |
]
|
| 38 |
try:
|
| 39 |
resp = client.chat.completions.create(
|
| 40 |
+
model=ll_model,
|
| 41 |
messages=messages,
|
| 42 |
temperature=0.1,
|
| 43 |
max_tokens=400
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def generate_response(user_query, docs):
|
| 74 |
+
# --- Collect context ---
|
| 75 |
context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)
|
| 76 |
+
|
| 77 |
+
# --- Build human-friendly sources + mapping ---
|
| 78 |
+
source_links = {}
|
| 79 |
for d in docs:
|
| 80 |
meta = d['metadata']
|
| 81 |
src = meta.get("source", "unknown").lower()
|
|
|
|
| 83 |
text_preview = " ".join(meta.get("text", "").split()[:30])
|
| 84 |
|
| 85 |
if src in ["constitution"]:
|
| 86 |
+
display_name = f"Constitution ({clean_chunk_id(cid)})"
|
| 87 |
|
| 88 |
elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
|
| 89 |
+
display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"
|
| 90 |
|
| 91 |
elif src in ["case_law", "case", "tax_case"]:
|
| 92 |
+
display_name = f"Case Law: {text_preview}..."
|
|
|
|
| 93 |
|
| 94 |
else:
|
| 95 |
+
display_name = f"{src.title()} ({clean_chunk_id(cid)})"
|
| 96 |
+
|
| 97 |
+
source_links[display_name] = meta.get("text", "")
|
| 98 |
|
| 99 |
+
# Sort sources by display name (duplicates are already collapsed by the dict keys)
|
| 100 |
+
source_links = dict(sorted(source_links.items()))
|
| 101 |
|
| 102 |
# --- System prompt ---
|
| 103 |
messages = [
|
|
|
|
| 111 |
"If multiple are used, separate them with commas."}
|
| 112 |
]
|
| 113 |
|
|
|
|
| 114 |
messages.extend(st.session_state.history)
|
| 115 |
+
|
| 116 |
messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
|
| 117 |
+
f"Sources:\n{', '.join(source_links.keys())}\n\n"
|
| 118 |
f"Question:\n{user_query}"})
|
| 119 |
try:
|
| 120 |
resp = client.chat.completions.create(
|
| 121 |
+
model=ll_model,
|
| 122 |
messages=messages,
|
| 123 |
temperature=0.1,
|
| 124 |
max_tokens=900
|
|
|
|
| 129 |
reply = "Sorry, I encountered an error generating the answer."
|
| 130 |
|
| 131 |
# Optional: force clean source line if LLM misses it
|
| 132 |
+
if source_links:
|
| 133 |
+
clean_sources = ", ".join(source_links.keys())
|
| 134 |
if "Source:" not in reply:
|
| 135 |
reply += f"\n\nSource: {clean_sources}"
|
| 136 |
|
| 137 |
+
# Save reply into history
|
| 138 |
st.session_state.history.append({"role": "assistant", "content": reply})
|
| 139 |
+
|
| 140 |
+
# --- Render in Streamlit ---
|
| 141 |
+
st.markdown(reply)
|
| 142 |
+
|
| 143 |
+
# Add expandable sources
|
| 144 |
+
if source_links:
|
| 145 |
+
st.write("### Sources")
|
| 146 |
+
for name, text in source_links.items():
|
| 147 |
+
with st.expander(name):
|
| 148 |
+
st.write(text)
|
| 149 |
+
|
| 150 |
return reply
|
| 151 |
|
| 152 |
|
| 153 |
|
| 154 |
+
|
| 155 |
# Chat UI
|
| 156 |
with st.form("chat_input", clear_on_submit=True):
|
| 157 |
user_input = st.text_input("You:", "")
|