PraneshJs committed on
Commit
6e820af
·
verified ·
1 Parent(s): 50c861a

added new features and changed embedding model

Browse files
Files changed (1) hide show
  1. app.py +33 -17
app.py CHANGED
@@ -15,8 +15,8 @@ redis_client = redis.Redis(
15
  password=os.getenv("REDIS_PASSWORD")
16
  )
17
 
18
- # 🧹 Clear Redis DB on startup
19
- redis_client.flushdb()
20
 
21
  # Azure OpenAI client (only for chat, not embeddings anymore)
22
  client = AzureOpenAI(
@@ -25,10 +25,10 @@ client = AzureOpenAI(
25
  azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT").strip()
26
  )
27
 
28
- CHAT_DEPLOYMENT = "gpt-4.1" # your Azure chat deployment
29
 
30
- # 🚀 Better embedding model from HF
31
- embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
32
 
33
  # Helper: get embedding from HF
34
  def get_embedding(text):
@@ -38,11 +38,12 @@ def get_embedding(text):
38
  def cosine_similarity(vec1, vec2):
39
  return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
40
 
41
- def search_cache(user_input, threshold=0.8):
42
  query_vec = get_embedding(user_input)
43
  best_key, best_score, best_val = None, -1, None
 
44
 
45
- for key, val in redis_client.hgetall("cache").items():
46
  entry = json.loads(val)
47
  vec = np.array(entry["embedding"], dtype=np.float32)
48
  score = cosine_similarity(query_vec, vec)
@@ -53,19 +54,24 @@ def search_cache(user_input, threshold=0.8):
53
  return best_val
54
  return None
55
 
56
- def store_cache(user_input, output):
57
  vec = get_embedding(user_input).tolist()
58
- redis_client.hset("cache", user_input, json.dumps({
 
59
  "embedding": vec,
60
  "output": output
61
  }))
62
 
63
- def chat_with_ai(user_input):
 
 
 
 
64
  if not user_input:
65
  return "Please type something."
66
 
67
  # 🔍 Check Redis semantic cache
68
- cached = search_cache(user_input)
69
  if cached:
70
  return f"[From Redis] {cached}"
71
 
@@ -79,27 +85,37 @@ def chat_with_ai(user_input):
79
  output = response.choices[0].message.content.strip()
80
 
81
  # 💾 Save with embedding in Redis
82
- store_cache(user_input, output)
83
 
84
  return f"[From OpenAI] {output}"
85
 
86
  # Gradio UI
87
  with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
88
- gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Semantic Cache) Demo")
 
 
 
89
  with gr.Row():
90
  chatbot = gr.Chatbot(type="messages")
 
91
  with gr.Row():
92
  msg = gr.Textbox(placeholder="Type your message here...")
93
  send = gr.Button("Send")
 
94
 
95
- def respond(message, history):
96
- bot_reply = chat_with_ai(message)
97
  history.append({"role": "user", "content": message})
98
  history.append({"role": "assistant", "content": bot_reply})
99
  return history, ""
100
 
101
- send.click(respond, [msg, chatbot], [chatbot, msg])
102
- msg.submit(respond, [msg, chatbot], [chatbot, msg])
 
 
 
 
 
103
 
104
  if __name__ == "__main__":
105
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)
 
15
  password=os.getenv("REDIS_PASSWORD")
16
  )
17
 
18
+ # 🧹 Do NOT flush the DB globally anymore, now that multiple users are supported
19
+ # redis_client.flushdb()
20
 
21
  # Azure OpenAI client (only for chat, not embeddings anymore)
22
  client = AzureOpenAI(
 
25
  azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT").strip()
26
  )
27
 
28
+ CHAT_DEPLOYMENT = "gpt-4.1"
29
 
30
+ # 🚀 Super lightweight multilingual embedding model
31
+ embedder = SentenceTransformer("intfloat/multilingual-e5-small")
32
 
33
  # Helper: get embedding from HF
34
  def get_embedding(text):
 
38
  def cosine_similarity(vec1, vec2):
39
  return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
40
 
41
+ def search_cache(user_id, user_input, threshold=0.8):
42
  query_vec = get_embedding(user_input)
43
  best_key, best_score, best_val = None, -1, None
44
+ cache_key = f"cache:{user_id}"
45
 
46
+ for key, val in redis_client.hgetall(cache_key).items():
47
  entry = json.loads(val)
48
  vec = np.array(entry["embedding"], dtype=np.float32)
49
  score = cosine_similarity(query_vec, vec)
 
54
  return best_val
55
  return None
56
 
57
+ def store_cache(user_id, user_input, output):
58
  vec = get_embedding(user_input).tolist()
59
+ cache_key = f"cache:{user_id}"
60
+ redis_client.hset(cache_key, user_input, json.dumps({
61
  "embedding": vec,
62
  "output": output
63
  }))
64
 
65
+ def clear_user_cache(user_id):
66
+ cache_key = f"cache:{user_id}"
67
+ redis_client.delete(cache_key)
68
+
69
+ def chat_with_ai(user_id, user_input):
70
  if not user_input:
71
  return "Please type something."
72
 
73
  # 🔍 Check Redis semantic cache
74
+ cached = search_cache(user_id, user_input)
75
  if cached:
76
  return f"[From Redis] {cached}"
77
 
 
85
  output = response.choices[0].message.content.strip()
86
 
87
  # 💾 Save with embedding in Redis
88
+ store_cache(user_id, user_input, output)
89
 
90
  return f"[From OpenAI] {output}"
91
 
92
  # Gradio UI
93
  with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
94
+ gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Semantic Cache, Multi-User)")
95
+
96
+ user_id = gr.Textbox(label="User ID", placeholder="Enter your username", value="guest")
97
+
98
  with gr.Row():
99
  chatbot = gr.Chatbot(type="messages")
100
+
101
  with gr.Row():
102
  msg = gr.Textbox(placeholder="Type your message here...")
103
  send = gr.Button("Send")
104
+ clear = gr.Button("🧹 Clear Cache")
105
 
106
+ def respond(message, history, user_id):
107
+ bot_reply = chat_with_ai(user_id, message)
108
  history.append({"role": "user", "content": message})
109
  history.append({"role": "assistant", "content": bot_reply})
110
  return history, ""
111
 
112
+ def clear_cache_ui(user_id, history):
113
+ clear_user_cache(user_id)
114
+ return [], f"✅ Cache cleared for {user_id}"
115
+
116
+ send.click(respond, [msg, chatbot, user_id], [chatbot, msg])
117
+ msg.submit(respond, [msg, chatbot, user_id], [chatbot, msg])
118
+ clear.click(clear_cache_ui, [user_id, chatbot], [chatbot, msg])
119
 
120
  if __name__ == "__main__":
121
  demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)