Zai commited on
Commit
c623717
·
1 Parent(s): cae3947

feat: implement basic RAG with pinecone

Browse files
Files changed (2) hide show
  1. app.py +67 -93
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,109 +1,83 @@
 
 
 
1
  import gradio as gr
2
  from openai import OpenAI
3
- import os
4
  from dotenv import load_dotenv
5
- from datasets import load_dataset
6
- from sklearn.metrics.pairwise import cosine_similarity
7
- from functools import lru_cache
8
- import numpy as np
9
-
10
 
11
  load_dotenv(verbose=True)
12
 
13
- # Create OpenAI client (make sure OPENAI_API_KEY is set in your environment)
14
- client = OpenAI(
15
- api_key=os.getenv("OPENAI_API_KEY"),
16
- )
17
-
18
- HF_DATASET = "carching/cs-data"
19
  EMBED_MODEL = "text-embedding-3-small"
20
  TOP_K = 5
21
 
22
- # ---- Load dataset ----
23
- def load_data():
24
- ds = load_dataset(HF_DATASET)
25
- messages = ds["train"]["message"]
26
- senders = ds["train"]["sender"]
27
- users = ds["train"]["user"]
28
- return messages, senders, users
29
 
30
- messages, senders, users = load_data()
31
-
32
- # ---- Embedding helper ----
33
- @lru_cache(maxsize=None)
34
- def get_embedding(text):
35
- resp = client.embeddings.create(
36
  model=EMBED_MODEL,
37
- input=text
38
  )
39
- return np.array(resp.data[0].embedding)
40
-
41
- # ---- Precompute dataset embeddings ----
42
- message_embeddings = [get_embedding(msg) for msg in messages]
43
-
44
- # ---- Retrieval ----
45
- def retrieve_context(user_query):
46
- query_emb = get_embedding(user_query).reshape(1, -1)
47
- all_embs = np.stack(message_embeddings)
48
- sims = cosine_similarity(query_emb, all_embs)[0]
49
- top_idx = np.argsort(sims)[-TOP_K:][::-1]
50
- return [(messages[i], senders[i], sims[i]) for i in top_idx]
51
-
52
- # ---- Chatbot function ----
53
- def respond(message, history, system_message, max_tokens, temperature, top_p):
54
- # Retrieve historical WhatsApp context
55
- context_rows = retrieve_context(message)
56
- context_text = "\n".join([f"{role}: {msg}" for msg, role, _ in context_rows])
57
-
58
- # Build system prompt with context
59
- full_system_message = (
60
- f"{system_message}\n\n"
61
- f"Use the following historical conversation context if relevant:\n\n"
62
- f"{context_text}"
63
  )
64
 
65
- # Convert Gradio history into OpenAI format
66
- chat_messages = [{"role": "system", "content": full_system_message}]
67
- for user_msg, assistant_msg in history:
68
- if user_msg:
69
- chat_messages.append({"role": "user", "content": user_msg})
70
- if assistant_msg:
71
- chat_messages.append({"role": "assistant", "content": assistant_msg})
72
- chat_messages.append({"role": "user", "content": message})
73
-
74
- # Stream responses from OpenAI
75
- response_text = ""
76
- stream = client.chat.completions.create(
77
- model="gpt-4o-mini",
78
- messages=chat_messages,
79
- max_tokens=max_tokens,
80
- temperature=temperature,
81
- top_p=top_p,
82
- stream=True
83
  )
84
 
85
- for chunk in stream:
86
- if chunk.choices[0].delta.content:
87
- token = chunk.choices[0].delta.content
88
- response_text += token
89
- yield response_text
90
-
91
- # ---- Gradio Interface ----
92
- demo = gr.ChatInterface(
93
- respond,
94
- additional_inputs=[
95
- gr.Textbox(value="You are a helpful support assistant.", label="System message"),
96
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
97
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
98
- gr.Slider(
99
- minimum=0.1,
100
- maximum=1.0,
101
- value=0.95,
102
- step=0.05,
103
- label="Top-p (nucleus sampling)"
104
- ),
105
- ],
106
- )
107
-
108
- if __name__ == "__main__":
109
- demo.launch()
 
 
 
1
+ import os
2
+ import json
3
+
4
  import gradio as gr
5
  from openai import OpenAI
6
+ from pinecone import Pinecone
7
  from dotenv import load_dotenv
 
 
 
 
 
8
 
9
  load_dotenv(verbose=True)
10
 
11
+ PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
12
+ INDEX_NAME = "whatsapp-history-1"
 
 
 
 
13
  EMBED_MODEL = "text-embedding-3-small"
14
  TOP_K = 5
15
 
16
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
17
+ pc = Pinecone(api_key=PINECONE_API_KEY)
18
+ index = pc.Index(INDEX_NAME)
 
 
 
 
19
 
20
+ def retrieve_context(query):
21
+ response = client.embeddings.create(
 
 
 
 
22
  model=EMBED_MODEL,
23
+ input=query
24
  )
25
+ query_emb = response.data[0].embedding
26
+
27
+ # Use keyword argument 'vector' for the query
28
+ result = index.query(vector=query_emb, top_k=TOP_K, include_metadata=True)
29
+
30
+ contexts = []
31
+ for match in result.matches:
32
+ meta = match.metadata
33
+ contexts.append(f"{meta['sender']}: {meta['message']}")
34
+
35
+ return "\n".join(contexts)
36
+
37
+ def respond(message, chat_history_json):
38
+ chat_history = json.loads(chat_history_json)
39
+
40
+ context = retrieve_context(message)
41
+
42
+ system_prompt = (
43
+ "You are a helpful assistant for carching. Use the following past conversation data on whatsapp "
44
+ "to answer the user's question if relevant:\n\n" + context
 
 
 
 
45
  )
46
 
47
+ messages = [{"role": "system", "content": system_prompt}]
48
+ messages.extend(chat_history or [])
49
+ messages.append({"role": "user", "content": message})
50
+
51
+ response = client.chat.completions.create(
52
+ model="gpt-5-mini",
53
+ messages=messages,
54
+ temperature=0.7
 
 
 
 
 
 
 
 
 
 
55
  )
56
 
57
+ bot_reply = response.choices[0].message.content
58
+
59
+ chat_history.append({"role": "user", "content": message})
60
+ chat_history.append({"role": "assistant", "content": bot_reply})
61
+
62
+ # Return the history directly and the serialized state
63
+ return chat_history, json.dumps(chat_history)
64
+
65
+ with gr.Blocks() as demo:
66
+ gr.Markdown("# Customer Support Chatbot")
67
+ # Set chatbot type to 'messages' to resolve the warning
68
+ chatbot = gr.Chatbot(type="messages")
69
+ msg = gr.Textbox(placeholder="Ask a question...", show_label=False)
70
+ state = gr.State(json.dumps([]))
71
+
72
+ with gr.Row():
73
+ submit_btn = gr.Button("Send")
74
+
75
+ def submit_message(msg, state):
76
+ # Clear the textbox after submission
77
+ return "", *respond(msg, state)
78
+
79
+ # Bind submit button AND hitting enter in textbox
80
+ submit_btn.click(fn=submit_message, inputs=[msg, state], outputs=[msg, chatbot, state])
81
+ msg.submit(fn=submit_message, inputs=[msg, state], outputs=[msg, chatbot, state])
82
+
83
+ demo.launch()
requirements.txt CHANGED
@@ -4,4 +4,5 @@ openai
4
  dotenv
5
  datasets
6
  scikit-learn
7
- numpy
 
 
4
  dotenv
5
  datasets
6
  scikit-learn
7
+ numpy
8
+ pinecone