File size: 6,201 Bytes
f97322b
ed14dc4
6640849
 
 
 
 
59a9179
f97322b
 
 
 
8d16824
ed14dc4
f97322b
 
 
8ca6217
f97322b
 
f6d9e3b
f97322b
1bd789c
f97322b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6d9e3b
f97322b
b832b0a
 
f97322b
 
 
 
 
1bd789c
f97322b
 
cf2f7c8
f97322b
 
 
 
 
 
 
8d16824
 
 
 
 
 
 
 
 
 
 
 
 
f97322b
f6d9e3b
f97322b
f6d9e3b
 
 
408354f
 
8d16824
 
 
 
 
f6d9e3b
408354f
8d16824
f6d9e3b
408354f
8d16824
f6d9e3b
408354f
 
f6d9e3b
 
 
408354f
f6d9e3b
 
8d16824
 
113896e
 
8d16824
 
 
 
 
 
 
113896e
408354f
f97322b
f6d9e3b
8d28ad2
f6d9e3b
8d28ad2
f97322b
 
f6d9e3b
f97322b
b832b0a
 
f97322b
 
 
 
 
dc064f8
113896e
f6d9e3b
 
113896e
 
cf2f7c8
f6d9e3b
f97322b
f6d9e3b
 
 
 
 
 
 
 
 
 
 
f97322b
 
dc064f8
113896e
f6d9e3b
f97322b
 
 
 
 
 
 
 
 
 
 
cbb87db
f97322b
1bd789c
 
cbb87db
f97322b
 
 
8b3e0c0
cbb87db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os

# Redirect every HuggingFace / sentence-transformers cache to /tmp BEFORE the
# libraries are imported, so model downloads work on read-only or ephemeral
# filesystems (containers, Streamlit Cloud). setdefault keeps any value the
# deployment environment already set.
os.environ.setdefault("HF_HOME", "/tmp/huggingface")
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/tmp/huggingface")
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/huggingface/st_models")

import streamlit as st
import openai
from collections import deque
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone
import re
# Setup — API keys are read from environment variables, not hardcoded.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-ai")  # Pinecone index holding the legal document chunks
model = SentenceTransformer('all-mpnet-base-v2')  # embedding model for retrieval
# NOTE(review): chat_history appears unused — the app reads/writes
# st.session_state.history below instead; candidate for removal.
chat_history = deque(maxlen=10)  # last 5 pairs = 10 messages
ll_model = 'gpt-4o-mini'  # chat-completion model id (name likely meant "llm_model")

st.title("AI Legal Assistant ⚖️")

# Persist the conversation across Streamlit reruns, capped at the last 10 messages.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)

def get_rewritten_query(user_query):
    """Rewrite *user_query* into a self-contained, search-ready query.

    Feeds the last few turns of the session history to the chat model so
    follow-up questions ("what about section 7?") become explicit before
    the vector-DB lookup. Returns the rewritten query, or the original
    query unchanged if the API call fails.
    """
    # Only the most recent 4 messages are needed for disambiguation.
    hist = list(st.session_state.history)[-4:]
    hist_text = "\n".join(f"{m['role']}: {m['content']}" for m in hist)
    messages = [
        {"role": "system", "content":
         "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If it's already clear then don't rewrite"},
        {"role": "user", "content":
         f"History:\n{hist_text}\n\nNew query:\n{user_query}\n\n"
         "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly."}
    ]
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,  # low temperature: near-deterministic rewrites
            max_tokens=400
        )
        rewritten = resp.choices[0].message.content.strip()
    except Exception as e:
        # Best-effort: surface the error but keep the app usable by
        # falling back to the user's original wording.
        st.error(f"Rewrite error: {e}")
        rewritten = user_query
    return rewritten

def retrieve_documents(query, top_k=10):
    """Embed *query* and fetch its top_k nearest chunks from Pinecone.

    Returns the raw list of match dicts (with metadata included), or an
    empty list when the index query fails.
    """
    vector = model.encode(query).tolist()
    try:
        result = index.query(vector=vector, top_k=top_k, include_metadata=True)
        return result['matches']
    except Exception as e:
        st.error(f"Retrieve error: {e}")
        return []


def clean_chunk_id(cid: str) -> str:
    """Convert a raw chunk id into a human-friendly label.

    Drops any trailing '_chunk...' suffix, turns underscores and dashes
    into spaces, and capitalizes each remaining word.
    """
    base = re.sub(r'_chunk.*$', '', cid)
    # Single-pass separator substitution instead of chained .replace() calls.
    spaced = base.translate(str.maketrans("_-", "  "))
    return " ".join(word.capitalize() for word in spaced.split())



def generate_response(user_query, docs):
    # --- Collect context ---
    context = "\n\n---\n\n".join(d['metadata']['text'] for d in docs)

    # --- Build human-friendly sources + mapping ---
    source_links = {}
    for d in docs:
        meta = d['metadata']
        src = meta.get("source", "unknown").lower()
        cid = meta.get("chunk_id", "")
        text_preview = " ".join(meta.get("text", "").split()[:30])

        if src in ["constitution"]:
            display_name = f"Constitution ({clean_chunk_id(cid)})"

        elif src in ["fbr_ordinance", "ordinance", "tax_ordinance"]:
            display_name = f"Tax Ordinance ({clean_chunk_id(cid)})"

        elif src in ["case_law", "case", "tax_case"]:
            display_name = f"Case Law: {text_preview}..."

        else:
            display_name = f"{src.title()} ({clean_chunk_id(cid)})"

        source_links[display_name] = meta.get("text", "")

    # Deduplicate
    source_links = dict(sorted(source_links.items()))

    # --- System prompt ---
    messages = [
        {"role": "system", "content":
         "You are a helpful legal assistant. Use the provided context from documents to answer the user's question. "
         "At the end of your answer, write a single line starting with 'Source: ' followed by the sources used. "
         "Formatting rules:\n"
         "- For Constitution / Ordinances: show the clean chunk id, no underscores/dashes, capitalized words.\n"
         "- For Case law: ignore chunk id, instead show first ~30 words of the case text.\n"
         "- Do not use technical terms like 'chunk'. Present sources in a human-friendly way.\n"
         "If multiple are used, separate them with commas."}
    ]

    messages.extend(st.session_state.history)

    messages.append({"role": "user", "content": f"Context:\n{context}\n\n"
                   f"Sources:\n{', '.join(source_links.keys())}\n\n"
                   f"Question:\n{user_query}"})
    try:
        resp = client.chat.completions.create(
            model=ll_model,
            messages=messages,
            temperature=0.1,
            max_tokens=900
        )
        reply = resp.choices[0].message.content.strip()
    except Exception as e:
        st.error(f"Response error: {e}")
        reply = "Sorry, I encountered an error generating the answer."

    # Optional: force clean source line if LLM misses it
    if source_links:
        clean_sources = ", ".join(source_links.keys())
        if "Source:" not in reply:
            reply += f"\n\nSource: {clean_sources}"

    # Save reply into history
    st.session_state.history.append({"role": "assistant", "content": reply})

    # --- Render in Streamlit ---
    st.markdown(reply)

    # Add expandable sources
    if source_links:
        st.write("### Sources")
        for name, text in source_links.items():
            with st.expander(name):
                st.write(text)

    return reply




# --- Chat UI ---
# clear_on_submit empties the text box after each send.
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    # Record the user's turn, then run the rewrite -> retrieve -> answer pipeline.
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    generate_response(rewritten, docs)

# --- Display history, newest first ---
st.markdown("---")
for position, msg in enumerate(reversed(st.session_state.history), start=1):
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Legal Assistant:** {msg['content']}")
    # Divider after every message except the first rendered (the original
    # expressed this with the opaque bitwise test `c ^ 1`).
    if position != 1:
        st.markdown("---")