Zai commited on
Commit
b6ca784
·
1 Parent(s): ee21bdf

want to test embeddings

Browse files
Files changed (2) hide show
  1. app.py +66 -26
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,6 +2,11 @@ import gradio as gr
2
  from openai import OpenAI
3
  import os
4
  from dotenv import load_dotenv
 
 
 
 
 
5
 
6
  load_dotenv(verbose=True)
7
 
@@ -10,31 +15,67 @@ client = OpenAI(
10
  api_key=os.getenv("OPENAI_API_KEY"),
11
  )
12
 
13
- def respond(
14
- message,
15
- history: list[tuple[str, str]],
16
- system_message,
17
- max_tokens,
18
- temperature,
19
- top_p,
20
- ):
21
- messages = [{"role": "system", "content": system_message}]
22
-
23
- # Convert Gradio's history into OpenAI's format
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  for user_msg, assistant_msg in history:
25
  if user_msg:
26
- messages.append({"role": "user", "content": user_msg})
27
  if assistant_msg:
28
- messages.append({"role": "assistant", "content": assistant_msg})
29
-
30
- messages.append({"role": "user", "content": message})
31
-
32
- response = ""
33
 
34
  # Stream responses from OpenAI
 
35
  stream = client.chat.completions.create(
36
- model="gpt-4o-mini", # You can change this to another available model
37
- messages=messages,
38
  max_tokens=max_tokens,
39
  temperature=temperature,
40
  top_p=top_p,
@@ -44,15 +85,14 @@ def respond(
44
  for chunk in stream:
45
  if chunk.choices[0].delta.content:
46
  token = chunk.choices[0].delta.content
47
- response += token
48
- yield response
49
-
50
 
51
- # Gradio Chat Interface
52
  demo = gr.ChatInterface(
53
  respond,
54
  additional_inputs=[
55
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
56
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
57
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
58
  gr.Slider(
@@ -60,10 +100,10 @@ demo = gr.ChatInterface(
60
  maximum=1.0,
61
  value=0.95,
62
  step=0.05,
63
- label="Top-p (nucleus sampling)",
64
  ),
65
  ],
66
  )
67
 
68
  if __name__ == "__main__":
69
- demo.launch()
 
2
  from openai import OpenAI
3
  import os
4
  from dotenv import load_dotenv
5
+ from datasets import load_dataset
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from functools import lru_cache
8
+ import numpy as np
9
+
10
 
11
  load_dotenv(verbose=True)
12
 
 
15
  api_key=os.getenv("OPENAI_API_KEY"),
16
  )
17
 
18
+ HF_DATASET = "carching/cs-data"
19
+ EMBED_MODEL = "text-embedding-3-small"
20
+ TOP_K = 5
21
+
22
+ # ---- Load dataset ----
23
+ def load_data():
24
+ ds = load_dataset(HF_DATASET)
25
+ messages = ds["train"]["message"]
26
+ senders = ds["train"]["sender"]
27
+ users = ds["train"]["user"]
28
+ return messages, senders, users
29
+
30
+ messages, senders, users = load_data()
31
+
32
+ # ---- Embedding helper ----
33
+ @lru_cache(maxsize=None)
34
+ def get_embedding(text):
35
+ resp = client.embeddings.create(
36
+ model=EMBED_MODEL,
37
+ input=text
38
+ )
39
+ return np.array(resp.data[0].embedding)
40
+
41
+ # ---- Precompute dataset embeddings ----
42
+ message_embeddings = [get_embedding(msg) for msg in messages]
43
+
44
+ # ---- Retrieval ----
45
+ def retrieve_context(user_query):
46
+ query_emb = get_embedding(user_query).reshape(1, -1)
47
+ all_embs = np.stack(message_embeddings)
48
+ sims = cosine_similarity(query_emb, all_embs)[0]
49
+ top_idx = np.argsort(sims)[-TOP_K:][::-1]
50
+ return [(messages[i], senders[i], sims[i]) for i in top_idx]
51
+
52
+ # ---- Chatbot function ----
53
+ def respond(message, history, system_message, max_tokens, temperature, top_p):
54
+ # Retrieve historical WhatsApp context
55
+ context_rows = retrieve_context(message)
56
+ context_text = "\n".join([f"{role}: {msg}" for msg, role, _ in context_rows])
57
+
58
+ # Build system prompt with context
59
+ full_system_message = (
60
+ f"{system_message}\n\n"
61
+ f"Use the following historical conversation context if relevant:\n\n"
62
+ f"{context_text}"
63
+ )
64
+
65
+ # Convert Gradio history into OpenAI format
66
+ chat_messages = [{"role": "system", "content": full_system_message}]
67
  for user_msg, assistant_msg in history:
68
  if user_msg:
69
+ chat_messages.append({"role": "user", "content": user_msg})
70
  if assistant_msg:
71
+ chat_messages.append({"role": "assistant", "content": assistant_msg})
72
+ chat_messages.append({"role": "user", "content": message})
 
 
 
73
 
74
  # Stream responses from OpenAI
75
+ response_text = ""
76
  stream = client.chat.completions.create(
77
+ model="gpt-4o-mini",
78
+ messages=chat_messages,
79
  max_tokens=max_tokens,
80
  temperature=temperature,
81
  top_p=top_p,
 
85
  for chunk in stream:
86
  if chunk.choices[0].delta.content:
87
  token = chunk.choices[0].delta.content
88
+ response_text += token
89
+ yield response_text
 
90
 
91
+ # ---- Gradio Interface ----
92
  demo = gr.ChatInterface(
93
  respond,
94
  additional_inputs=[
95
+ gr.Textbox(value="You are a helpful support assistant.", label="System message"),
96
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
97
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
98
  gr.Slider(
 
100
  maximum=1.0,
101
  value=0.95,
102
  step=0.05,
103
+ label="Top-p (nucleus sampling)"
104
  ),
105
  ],
106
  )
107
 
108
  if __name__ == "__main__":
109
+ demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  huggingface_hub
2
  gradio
3
  openai
4
- dotenv
 
 
1
  huggingface_hub
2
  gradio
3
  openai
4
+ dotenv
5
+ datasets