omarkashif committed on
Commit
f97322b
·
verified ·
1 Parent(s): 9806424

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +89 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,91 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
4
  import streamlit as st
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
# Cache-location overrides kept for reference for sandboxed hosts (e.g. HF
# Spaces) where the default HF cache directory is read-only. Uncomment if
# model downloads fail with permission errors.
# os.environ.setdefault("HF_HOME", "/home/user/huggingface_cache")
# os.environ.setdefault("TRANSFORMERS_CACHE", "/home/user/huggingface_cache/transformers")
# os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/home/user/huggingface_cache/sentence_transformers")
import streamlit as st
import openai
from collections import deque
from sentence_transformers import SentenceTransformer
from pinecone import Pinecone

# Client setup — credentials are read from environment variables, never
# hardcoded. (Previous comment claimed "hardcoded keys"; it was wrong.)
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index("legal-bot")
model = SentenceTransformer('all-mpnet-base-v2')
# NOTE(review): removed unused module-level `chat_history` deque — the app
# keeps its transcript in st.session_state.history exclusively.

st.title("🔍 Legal RAG Assistant (Streamlit)")

# Rolling transcript: maxlen=10 keeps the last 5 user/assistant pairs.
if "history" not in st.session_state:
    st.session_state.history = deque(maxlen=10)
def get_rewritten_query(user_query):
    """Turn *user_query* into a standalone, search-ready query.

    The last few messages of the session history are handed to the model as
    context so follow-up questions get expanded into explicit ones suitable
    for vector-DB lookup. On any API failure the original query is used
    unchanged. A note recording the rewritten query is appended to the
    session history before returning.
    """
    recent = list(st.session_state.history)[-4:]
    transcript = "\n".join(f"{turn['role']}: {turn['content']}" for turn in recent)
    rewritten = user_query  # fallback if the API call fails
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.3,
            max_tokens=100,
            messages=[
                {
                    "role": "system",
                    "content": "You are a legal assistant that rewrites user queries into clear, context-aware queries for vector DB lookup. If its already clear then dont rewite",
                },
                {
                    "role": "user",
                    "content": f"History:\n{transcript}\n\nNew query:\n{user_query}\n\n"
                               "Rewrite if needed for clarity/search purposes. Otherwise, repeat exactly.",
                },
            ],
        )
        rewritten = response.choices[0].message.content.strip()
    except Exception as err:
        st.error(f"Rewrite error: {err}")
    st.session_state.history.append({"role": "assistant", "content": f"🔁 Rewritten query: {rewritten}"})
    return rewritten
def retrieve_documents(query, top_k=5):
    """Embed *query* and return the top_k nearest matches from Pinecone.

    On an index-query failure the error is surfaced in the UI and an empty
    list is returned so the caller can proceed with no context.
    """
    vector = model.encode(query).tolist()
    try:
        result = index.query(vector=vector, top_k=top_k, include_metadata=True)
        return result['matches']
    except Exception as err:
        st.error(f"Retrieve error: {err}")
        return []
def generate_response(user_query, docs):
    """Answer *user_query* using only the retrieved document context.

    Builds a chat payload (system prompt, full session history, then the
    context-augmented question), calls the model, appends the reply to the
    session history, and returns it. A fixed apology string stands in for
    the reply when the API call fails.
    """
    context = "\n\n---\n\n".join(doc['metadata']['text'] for doc in docs)
    payload = [
        {
            "role": "system",
            "content": "You are a helpful legal assistant. Use provided context from documents. Answer only using the context.",
        }
    ]
    payload += list(st.session_state.history)
    payload.append({"role": "user", "content": f"Context:\n{context}\n\nQuestion:\n{user_query}"})
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=payload,
            temperature=0.2,
            max_tokens=500,
        )
        reply = response.choices[0].message.content.strip()
    except Exception as err:
        st.error(f"Response error: {err}")
        reply = "Sorry, I encountered an error generating the answer."
    st.session_state.history.append({"role": "assistant", "content": reply})
    return reply
# --- Chat UI -------------------------------------------------------------
# The form clears its text box on submit; the transcript below is
# re-rendered from session history on every Streamlit rerun.
with st.form("chat_input", clear_on_submit=True):
    user_input = st.text_input("You:", "")
    submit = st.form_submit_button("Send")

if submit and user_input:
    st.session_state.history.append({"role": "user", "content": user_input})
    rewritten = get_rewritten_query(user_input)
    docs = retrieve_documents(rewritten)
    # generate_response appends the reply to session history itself, so the
    # return value is not needed here (was previously bound to an unused
    # `assistant_reply` local).
    generate_response(rewritten, docs)

# Render the transcript (includes rewritten-query notices and replies).
for msg in st.session_state.history:
    if msg["role"] == "user":
        st.markdown(f"**You:** {msg['content']}")
    else:
        st.markdown(f"**Assistant:** {msg['content']}")