PraneshJs committed on
Commit
142282b
·
verified ·
1 Parent(s): cc33aed

improved the embedding for languages

Browse files
Files changed (1) hide show
  1. app.py +169 -72
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import gradio as gr
3
  import redis
4
  import numpy as np
@@ -7,122 +8,215 @@ from datetime import timedelta
7
  from openai import AzureOpenAI
8
  from sentence_transformers import SentenceTransformer
9
 
10
- # Redis Cloud connection
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  redis_client = redis.Redis(
12
- host="redis-12628.c14.us-east-1-2.ec2.redns.redis-cloud.com",
13
- port=12628,
14
  decode_responses=True,
15
- username="default",
16
- password=os.getenv("REDIS_PASSWORD")
17
  )
18
 
19
- # Azure OpenAI client
20
  client = AzureOpenAI(
21
- api_key=os.getenv("AZURE_OPENAI_API_KEY").strip(),
22
- api_version="2025-01-01-preview",
23
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT").strip()
24
  )
25
 
26
- CHAT_DEPLOYMENT = "gpt-4.1"
27
-
28
- # 🚀 Super lightweight multilingual embedding model
29
  embedder = SentenceTransformer("intfloat/multilingual-e5-small")
30
 
31
- # Cache expiration: 2 days (in seconds)
32
- CACHE_TTL = int(timedelta(days=2).total_seconds())
33
-
34
- # Helper: get embedding
35
- def get_embedding(text):
36
- return embedder.encode(text, convert_to_numpy=True).astype(np.float32)
37
-
38
- # Helper: cosine similarity
39
- def cosine_similarity(vec1, vec2):
40
- return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- def search_cache(user_id, user_input, threshold=0.9): # stricter threshold
43
- query_vec = get_embedding(user_input)
44
- best_key, best_score, best_val = None, -1, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  cache_key = f"cache:{user_id}"
 
 
 
 
 
 
 
 
 
46
 
47
- for key, val in redis_client.hgetall(cache_key).items():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  entry = json.loads(val)
 
 
49
  vec = np.array(entry["embedding"], dtype=np.float32)
50
  score = cosine_similarity(query_vec, vec)
51
  if score > best_score:
52
- best_score, best_key, best_val = score, key, entry["output"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- if best_score >= threshold:
55
- return best_val
56
  return None
57
 
58
- def store_cache(user_id, user_input, output):
59
- vec = get_embedding(user_input).tolist()
60
- cache_key = f"cache:{user_id}"
61
-
62
- # Store extra context: include language keyword if present
63
- context_input = user_input.lower()
64
- if "java" in context_input:
65
- context_input = "JAVA: " + context_input
66
- elif "python" in context_input:
67
- context_input = "PYTHON: " + context_input
68
- elif "c++" in context_input or "cpp" in context_input:
69
- context_input = "CPP: " + context_input
70
- elif "c " in context_input:
71
- context_input = "C: " + context_input
72
-
73
- redis_client.hset(cache_key, context_input, json.dumps({
74
- "embedding": vec,
75
- "output": output
76
- }))
77
- redis_client.expire(cache_key, CACHE_TTL)
78
 
79
- def clear_user_cache(user_id):
80
- cache_key = f"cache:{user_id}"
81
- redis_client.delete(cache_key)
82
-
83
- def view_user_cache(user_id):
84
  cache_key = f"cache:{user_id}"
85
  entries = redis_client.hgetall(cache_key)
86
  if not entries:
87
  return "⚠️ No cache stored."
88
  lines = []
89
- for q, val in entries.items():
90
- entry = json.loads(val)
91
- lines.append(f"**Q:** {q}\n**A:** {entry['output']}")
 
 
 
92
  return "\n\n---\n\n".join(lines)
93
 
94
- def chat_with_ai(user_id, user_input):
95
- if not user_input:
96
- return "Please type something."
 
 
 
97
 
98
- # 🔍 Check Redis semantic cache
99
  cached = search_cache(user_id, user_input)
100
  if cached:
101
  return f"[From Redis] {cached}"
102
 
103
- # Otherwise query Azure OpenAI
104
  response = client.chat.completions.create(
105
  model=CHAT_DEPLOYMENT,
106
  messages=[{"role": "user", "content": user_input}],
107
  temperature=0.8,
108
- max_tokens=700
109
  )
110
  output = response.choices[0].message.content.strip()
111
 
112
- # 💾 Save with embedding in Redis
113
  store_cache(user_id, user_input, output)
114
-
115
  return f"[From OpenAI] {output}"
116
 
 
117
  # Gradio UI
118
- with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
119
- gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Semantic Cache, Multi-User, Auto Clean)")
 
120
 
121
  user_id_state = gr.State("")
122
 
123
  with gr.Row():
124
  user_id_input = gr.Textbox(label="Enter Username (only once)", placeholder="Your username")
125
  save_user = gr.Button("✅ Save Username")
 
126
 
127
  with gr.Row():
128
  chatbot = gr.Chatbot(type="messages")
@@ -132,11 +226,14 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
132
  send = gr.Button("Send")
133
 
134
  with gr.Row():
135
- clear = gr.Button("🧹 Clear Cache")
136
- view = gr.Button("👀 View Cache")
137
  cache_output = gr.Markdown("")
138
 
139
- def set_user_id(uid):
 
 
 
140
  return uid, f"✅ Username set as **{uid}**"
141
 
142
  def respond(message, history, user_id):
@@ -158,11 +255,11 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
158
  return "⚠️ Please set username first!"
159
  return view_user_cache(user_id)
160
 
161
- save_user.click(set_user_id, user_id_input, [user_id_state, cache_output])
162
  send.click(respond, [msg, chatbot, user_id_state], [chatbot, msg])
163
  msg.submit(respond, [msg, chatbot, user_id_state], [chatbot, msg])
164
  clear.click(clear_cache_ui, [user_id_state, chatbot], [chatbot, cache_output])
165
  view.click(view_cache_ui, user_id_state, cache_output)
166
 
167
  if __name__ == "__main__":
168
- demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)
 
1
  import os
2
+ import re
3
  import gradio as gr
4
  import redis
5
  import numpy as np
 
8
  from openai import AzureOpenAI
9
  from sentence_transformers import SentenceTransformer
10
 
11
+ # -----------------------
12
+ # Configuration
13
+ # -----------------------
14
+ REDIS_HOST = "redis-12628.c14.us-east-1-2.ec2.redns.redis-cloud.com"
15
+ REDIS_PORT = 12628
16
+ REDIS_USER = "default"
17
+ REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")
18
+
19
+ AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "").strip()
20
+ AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "").strip()
21
+ AZURE_API_VERSION = "2025-01-01-preview"
22
+ CHAT_DEPLOYMENT = "gpt-4.1"
23
+
24
+ # Cache TTL (2 days)
25
+ CACHE_TTL = int(timedelta(days=2).total_seconds())
26
+
27
+ # Matching thresholds
28
+ PRIMARY_THRESHOLD = 0.90 # for same-language matches
29
+ FALLBACK_THRESHOLD = 0.95 # for language-agnostic fallback (very strict)
30
+
31
+ # -----------------------
32
+ # Clients / Models
33
+ # -----------------------
34
  redis_client = redis.Redis(
35
+ host=REDIS_HOST,
36
+ port=REDIS_PORT,
37
  decode_responses=True,
38
+ username=REDIS_USER,
39
+ password=REDIS_PASSWORD,
40
  )
41
 
 
42
  client = AzureOpenAI(
43
+ api_key=AZURE_API_KEY,
44
+ api_version=AZURE_API_VERSION,
45
+ azure_endpoint=AZURE_ENDPOINT,
46
  )
47
 
48
+ # Embedding model (multilingual, small & strong)
 
 
49
  embedder = SentenceTransformer("intfloat/multilingual-e5-small")
50
 
51
+ # -----------------------
52
+ # Helpers
53
+ # -----------------------
54
def detect_language_tag(text: str):
    """Return a canonical language tag (lowercase) found in *text*, or None.

    Patterns are tried in order and the first match wins, so the more
    specific names (e.g. "c++") must appear before the generic "c" heuristics.
    """
    t = text.lower()
    patterns = [
        (r'\bjava\b', "java"),
        (r'\bpython\b', "python"),
        # BUGFIX: '+' and '#' are not word characters, so a trailing \b after
        # them can never match at end-of-token — the old r'\b(c\+\+|cpp)\b'
        # missed "in c++" (which then fell through to the generic "c" rule)
        # and r'\bc#\b' missed "c# ". Anchor only the leading boundary for
        # the symbol-suffixed names.
        (r'\bc\+\+|\bcpp\b', "cpp"),
        (r'\bc#|\bcsharp\b', "csharp"),
        (r'\bjavascript\b|\bjs\b', "javascript"),
        # NOTE(review): \b(go)\b also matches the English verb "go" — a known
        # false positive, kept for backward compatibility.
        (r'\b(go|golang)\b', "go"),
        (r'\bruby\b', "ruby"),
        (r'\bphp\b', "php"),
        (r'\bscala\b', "scala"),
        (r'\br\b', "r"),
        # C detection is tricky; look for " in c", " c language", or standalone " c "
        (r'\b in c\b|\bc language\b|\b c \b', "c"),
    ]
    for pat, tag in patterns:
        if re.search(pat, t):
            return tag
    return None
75
 
76
+ def build_embedding_input(text: str, lang_tag: str | None):
77
+ """Create the text to embed: include language tag prefix if present."""
78
+ if lang_tag:
79
+ return f"{lang_tag.upper()}: {text}"
80
+ return text
81
+
82
def get_embedding(text: str) -> np.ndarray:
    """Embed *text* with the shared sentence-transformer and return a float32 vector."""
    # encode() returns a numpy array; cast once so all cache math uses float32.
    return embedder.encode(text, convert_to_numpy=True).astype(np.float32)
85
+
86
def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Cosine similarity of two vectors; 0.0 when either vector has zero norm."""
    norm_a = np.linalg.norm(vec1)
    norm_b = np.linalg.norm(vec2)
    # Guard against division by zero for degenerate (all-zero) embeddings.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / (norm_a * norm_b))
93
+
94
+ # -----------------------
95
+ # Cache functions
96
+ # -----------------------
97
def store_cache(user_id: str, user_input: str, output: str):
    """Persist a Q/A pair under the user's cache hash with a language-aware embedding, refreshing the TTL."""
    lang = detect_language_tag(user_input)
    embedding = get_embedding(build_embedding_input(user_input, lang)).tolist()
    # Hash field is prefixed with the language tag (when detected) so entries
    # for different languages never collide on the same question text.
    field = f"{lang}:{user_input}" if lang else user_input
    record = json.dumps({
        "orig": user_input,
        "embedding": embedding,
        "output": output,
        "lang": lang,
    })
    cache_key = f"cache:{user_id}"
    redis_client.hset(cache_key, field, record)
    # Sliding expiry: every write renews the whole per-user hash.
    redis_client.expire(cache_key, CACHE_TTL)
111
 
112
def search_cache(user_id: str, user_input: str, primary_threshold=PRIMARY_THRESHOLD, fallback_threshold=FALLBACK_THRESHOLD):
    """Language-aware semantic cache lookup for *user_id*.

    Match order (first hit wins):
      1. entries tagged with the same detected language  (>= primary_threshold)
      2. language-agnostic entries (lang is None)        (>= fallback_threshold)
      3. any entry, regardless of language               (>= fallback_threshold)

    Returns the cached output string, or None on a miss.

    Refactor: the original ran three near-identical scan loops and called
    json.loads on every hash value in each pass; entries are now parsed once
    and the scan is factored into a single local helper. Behavior (ordering,
    thresholds, first-best tie handling) is unchanged.
    """
    cache_key = f"cache:{user_id}"
    raw = redis_client.hgetall(cache_key)
    if not raw:
        return None

    # Embed the query with the same language-prefix logic used at store time
    # so query and stored vectors live in the same embedding space.
    detected_lang = detect_language_tag(user_input)
    query_vec = get_embedding(build_embedding_input(user_input, detected_lang))

    entries = [json.loads(v) for v in raw.values()]

    def _best(keep):
        """Best (score, output) over entries accepted by the *keep* predicate."""
        best_score, best_output = -1.0, None
        for entry in entries:
            if not keep(entry):
                continue
            vec = np.array(entry["embedding"], dtype=np.float32)
            score = cosine_similarity(query_vec, vec)
            if score > best_score:
                best_score, best_output = score, entry["output"]
        return best_score, best_output

    # 1) Same-language matches (only when a language was detected).
    if detected_lang:
        score, output = _best(lambda e: e.get("lang") == detected_lang)
        if score >= primary_threshold:
            return output

    # 2) Language-agnostic entries.
    score, output = _best(lambda e: e.get("lang") is None)
    if score >= fallback_threshold:
        return output

    # 3) Final fallback: any language, but require very high similarity.
    score, output = _best(lambda e: True)
    if score >= fallback_threshold:
        return output

    return None
165
 
166
def clear_user_cache(user_id: str):
    """Delete every cached entry for *user_id* in one Redis call."""
    key = f"cache:{user_id}"
    redis_client.delete(key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
def view_user_cache(user_id: str):
    """Render the user's cached Q/A pairs as markdown; warn when the cache is empty."""
    entries = redis_client.hgetall(f"cache:{user_id}")
    if not entries:
        return "⚠️ No cache stored."
    blocks = []
    for field, raw in entries.items():
        entry = json.loads(raw)
        lang = entry.get("lang") or "general"
        # Older entries may lack "orig"; fall back to the raw hash field name.
        question = entry.get("orig", field)
        answer = entry.get("output", "")
        blocks.append(f"**Lang:** {lang}\n**Q:** {question}\n**A:** {answer}")
    return "\n\n---\n\n".join(blocks)
182
 
183
+ # -----------------------
184
+ # Chat logic
185
+ # -----------------------
186
def chat_with_ai(user_id: str, user_input: str):
    """Answer a user message, preferring the Redis semantic cache over Azure OpenAI."""
    if not user_id or not user_input:
        return "Please set a username and type something."

    # Language-aware semantic cache lookup first.
    cached = search_cache(user_id, user_input)
    if cached:
        return f"[From Redis] {cached}"

    # Cache miss: ask the deployed chat model.
    response = client.chat.completions.create(
        model=CHAT_DEPLOYMENT,
        messages=[{"role": "user", "content": user_input}],
        temperature=0.8,
        max_tokens=700,
    )
    output = response.choices[0].message.content.strip()

    # Remember the answer under a language-aware embedding for next time.
    store_cache(user_id, user_input, output)
    return f"[From OpenAI] {output}"
207
 
208
+ # -----------------------
209
  # Gradio UI
210
+ # -----------------------
211
+ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat (Lang-aware)") as demo:
212
+ gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Language-aware Semantic Cache)")
213
 
214
  user_id_state = gr.State("")
215
 
216
  with gr.Row():
217
  user_id_input = gr.Textbox(label="Enter Username (only once)", placeholder="Your username")
218
  save_user = gr.Button("✅ Save Username")
219
+ user_status = gr.Markdown("")
220
 
221
  with gr.Row():
222
  chatbot = gr.Chatbot(type="messages")
 
226
  send = gr.Button("Send")
227
 
228
  with gr.Row():
229
+ clear = gr.Button("🧹 Clear My Cache")
230
+ view = gr.Button("👀 View My Cache")
231
  cache_output = gr.Markdown("")
232
 
233
def set_user_id(uid: str):
    """Validate the username; return (state_value, status_markdown)."""
    cleaned = uid.strip()
    if cleaned:
        return cleaned, f"✅ Username set as **{cleaned}**"
    return "", "⚠️ Please enter a non-empty username."
238
 
239
  def respond(message, history, user_id):
 
255
  return "⚠️ Please set username first!"
256
  return view_user_cache(user_id)
257
 
258
+ save_user.click(set_user_id, user_id_input, [user_id_state, user_status])
259
  send.click(respond, [msg, chatbot, user_id_state], [chatbot, msg])
260
  msg.submit(respond, [msg, chatbot, user_id_state], [chatbot, msg])
261
  clear.click(clear_cache_ui, [user_id_state, chatbot], [chatbot, cache_output])
262
  view.click(view_cache_ui, user_id_state, cache_output)
263
 
264
  if __name__ == "__main__":
265
+ demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)