devilsa committed on
Commit
9abfbc3
·
verified ·
1 Parent(s): 31eb18a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -1,41 +1,45 @@
1
  import os
 
2
  import streamlit as st
3
  import faiss
4
- import numpy as np
5
  from sentence_transformers import SentenceTransformer
6
  from groq import Groq
7
 
8
- # --- Load API key from environment (HF Repo Secrets) ---
9
  API_KEY = os.getenv("GROQ_API_KEY")
10
  if not API_KEY:
11
  st.error(
12
- "GROQ_API_KEY not found. In your Space go to: "
13
- "Settings → Repository secrets → Add new secret (Name: GROQ_API_KEY, Value: <REDACTED — a live Groq API key was committed here; it is exposed in history and must be revoked and rotated>)."
 
 
14
  )
15
  st.stop()
16
 
17
- # --- Init Groq client (no key printed/logged) ---
18
  client = Groq(api_key=API_KEY)
19
 
20
- # --- Cache the embedding model to speed up reloads ---
21
  @st.cache_resource
22
  def load_embedder():
 
23
  return SentenceTransformer("all-MiniLM-L6-v2")
24
 
25
  embedding_model = load_embedder()
26
 
27
- # --- FAISS index (384 dims for MiniLM) ---
28
- dimension = 384
29
  if "faiss_index" not in st.session_state:
30
- st.session_state.faiss_index = faiss.IndexFlatL2(dimension)
31
  if "chunks_store" not in st.session_state:
32
  st.session_state.chunks_store = []
33
 
34
  index = st.session_state.faiss_index
35
  chunks_store = st.session_state.chunks_store
36
 
37
- # ---- Utilities ----
38
- def chunk_text(text, max_length=500):
 
 
39
  words, chunks, cur = text.split(), [], []
40
  for w in words:
41
  if len(" ".join(cur)) + len(w) + 1 <= max_length:
@@ -47,25 +51,27 @@ def chunk_text(text, max_length=500):
47
  chunks.append(" ".join(cur))
48
  return chunks
49
 
 
50
  def embed_and_store(chunks):
51
  if not chunks:
52
  return
53
- embs = embedding_model.encode(chunks, convert_to_numpy=True, normalize_embeddings=False)
54
- # Ensure float32 for FAISS
55
- embs = np.asarray(embs, dtype="float32")
56
  index.add(embs)
57
  chunks_store.extend(chunks)
58
 
 
59
  def query_llm(prompt: str) -> str:
60
- # Streaming chat completion
61
  stream = client.chat.completions.create(
62
  model="deepseek-r1-distill-llama-70b",
63
  messages=[
64
  {
65
  "role": "system",
66
  "content": (
67
- "You are a relationship counselor. Analyze the WhatsApp conversation and "
68
- "provide insights on red flags, toxicity, and improvements. "
69
  "Start every answer with: 'Toxicity score: X/10'."
70
  ),
71
  },
@@ -83,7 +89,9 @@ def query_llm(prompt: str) -> str:
83
  out.append(delta)
84
  return "".join(out)
85
 
86
- # ---- UI ----
 
 
87
  st.title("AI Relationship Counsellor")
88
 
89
  uploaded_file = st.file_uploader("Upload a .txt export of your WhatsApp chat", type=["txt"])
@@ -97,13 +105,12 @@ if uploaded_file:
97
 
98
  user_query = st.text_input("Ask a question about your relationship:")
99
  if user_query:
100
- # Search top-k relevant chunks
101
- k = min(5, index.ntotal) if index.ntotal > 0 else 0
102
- if k == 0:
103
- st.warning("No text indexed yet. Please upload a chat file.")
104
  else:
105
- q_emb = embedding_model.encode([user_query], convert_to_numpy=True)
106
- q_emb = np.asarray(q_emb, dtype="float32")
 
107
  distances, idxs = index.search(q_emb, k)
108
  relevant = [chunks_store[i] for i in idxs[0] if 0 <= i < len(chunks_store)]
109
 
@@ -115,3 +122,5 @@ if uploaded_file:
115
 
116
  st.markdown("### AI Analysis")
117
  st.write(answer)
 
 
 
1
  import os
2
+ import numpy as np
3
  import streamlit as st
4
  import faiss
 
5
  from sentence_transformers import SentenceTransformer
6
  from groq import Groq
7
 
8
+ # ---------- Secrets / API Key ----------
9
  API_KEY = os.getenv("GROQ_API_KEY")
10
  if not API_KEY:
11
  st.error(
12
+ "GROQ_API_KEY not found.\n\n"
13
+ "Go to your Space → Settings → Repository secrets → Add new secret\n"
14
+ "Name: GROQ_API_KEY | Value: <your Groq key>\n\n"
15
+ "Then Rebuild/Restart this Space."
16
  )
17
  st.stop()
18
 
19
+ # ---------- Groq Client ----------
20
  client = Groq(api_key=API_KEY)
21
 
22
+ # ---------- Models / Index ----------
23
  @st.cache_resource
24
  def load_embedder():
25
+ # 384-dim embeddings
26
  return SentenceTransformer("all-MiniLM-L6-v2")
27
 
28
  embedding_model = load_embedder()
29
 
30
+ DIM = 384 # all-MiniLM-L6-v2 dimension
 
31
  if "faiss_index" not in st.session_state:
32
+ st.session_state.faiss_index = faiss.IndexFlatL2(DIM)
33
  if "chunks_store" not in st.session_state:
34
  st.session_state.chunks_store = []
35
 
36
  index = st.session_state.faiss_index
37
  chunks_store = st.session_state.chunks_store
38
 
39
+
40
+ # ---------- Helpers ----------
41
+ def chunk_text(text: str, max_length: int = 500):
42
+ """Simple whitespace chunker by character budget."""
43
  words, chunks, cur = text.split(), [], []
44
  for w in words:
45
  if len(" ".join(cur)) + len(w) + 1 <= max_length:
 
51
  chunks.append(" ".join(cur))
52
  return chunks
53
 
54
+
55
  def embed_and_store(chunks):
56
  if not chunks:
57
  return
58
+ embs = embedding_model.encode(
59
+ chunks, convert_to_numpy=True, normalize_embeddings=False
60
+ ).astype("float32")
61
  index.add(embs)
62
  chunks_store.extend(chunks)
63
 
64
+
65
  def query_llm(prompt: str) -> str:
66
+ """Stream a response from Groq and return full text."""
67
  stream = client.chat.completions.create(
68
  model="deepseek-r1-distill-llama-70b",
69
  messages=[
70
  {
71
  "role": "system",
72
  "content": (
73
+ "You are a relationship counselor. Analyze the WhatsApp conversation "
74
+ "and provide insights on red flags, toxicity, and improvements. "
75
  "Start every answer with: 'Toxicity score: X/10'."
76
  ),
77
  },
 
89
  out.append(delta)
90
  return "".join(out)
91
 
92
+
93
+ # ---------- UI ----------
94
+ st.set_page_config(page_title="AI Relationship Counsellor", layout="centered")
95
  st.title("AI Relationship Counsellor")
96
 
97
  uploaded_file = st.file_uploader("Upload a .txt export of your WhatsApp chat", type=["txt"])
 
105
 
106
  user_query = st.text_input("Ask a question about your relationship:")
107
  if user_query:
108
+ if index.ntotal == 0:
109
+ st.warning("Nothing indexed yet. Please upload a chat file.")
 
 
110
  else:
111
+ # top-k retrieval
112
+ k = min(5, index.ntotal)
113
+ q_emb = embedding_model.encode([user_query], convert_to_numpy=True).astype("float32")
114
  distances, idxs = index.search(q_emb, k)
115
  relevant = [chunks_store[i] for i in idxs[0] if 0 <= i < len(chunks_store)]
116
 
 
122
 
123
  st.markdown("### AI Analysis")
124
  st.write(answer)
125
+ else:
126
+ st.info("Upload a WhatsApp chat (.txt) to begin.")