Spaces: Sleeping
improved the embedding for languages

app.py CHANGED
@@ -1,4 +1,5 @@
 import os
 import gradio as gr
 import redis
 import numpy as np
@@ -7,122 +8,215 @@ from datetime import timedelta
 from openai import AzureOpenAI
 from sentence_transformers import SentenceTransformer
 
-#
 redis_client = redis.Redis(
-    host=
-    port=
     decode_responses=True,
-    username=
-    password=
 )
 
-# Azure OpenAI client
 client = AzureOpenAI(
-    api_key=
-    api_version=
-    azure_endpoint=
 )
 
-
-
-# 🚀 Super lightweight multilingual embedding model
 embedder = SentenceTransformer("intfloat/multilingual-e5-small")
 
-#
-
-
-
-
-
-
-
-
-
 
-def search_cache(
-
-
     cache_key = f"cache:{user_id}"
 
-
         entry = json.loads(val)
         vec = np.array(entry["embedding"], dtype=np.float32)
         score = cosine_similarity(query_vec, vec)
         if score > best_score:
-            best_score, best_val = score, entry["output"]
 
-    if best_score >= threshold:
-        return best_val
     return None
 
-def store_cache(user_id, user_input, output):
-
-    cache_key = f"cache:{user_id}"
-
-    # Store extra context: include language keyword if present
-    context_input = user_input.lower()
-    if "java" in context_input:
-        context_input = "JAVA: " + context_input
-    elif "python" in context_input:
-        context_input = "PYTHON: " + context_input
-    elif "c++" in context_input or "cpp" in context_input:
-        context_input = "CPP: " + context_input
-    elif "c " in context_input:
-        context_input = "C: " + context_input
-
-    redis_client.hset(cache_key, context_input, json.dumps({
-        "embedding": vec,
-        "output": output
-    }))
-    redis_client.expire(cache_key, CACHE_TTL)
 
-def clear_user_cache(user_id):
-    cache_key = f"cache:{user_id}"
-    redis_client.delete(cache_key)
-
-def view_user_cache(user_id):
     cache_key = f"cache:{user_id}"
     entries = redis_client.hgetall(cache_key)
     if not entries:
         return "⚠️ No cache stored."
     lines = []
-    for
-        entry = json.loads(
-
     return "\n\n---\n\n".join(lines)
 
-
-
-def chat_with_ai(user_id, user_input):
 
-    #
     cached = search_cache(user_id, user_input)
     if cached:
         return f"[From Redis] {cached}"
 
-    #
     response = client.chat.completions.create(
         model=CHAT_DEPLOYMENT,
         messages=[{"role": "user", "content": user_input}],
         temperature=0.8,
-        max_tokens=700
     )
     output = response.choices[0].message.content.strip()
 
-    #
     store_cache(user_id, user_input, output)
-
     return f"[From OpenAI] {output}"
 
 # Gradio UI
-with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
-
 
     user_id_state = gr.State("")
 
     with gr.Row():
         user_id_input = gr.Textbox(label="Enter Username (only once)", placeholder="Your username")
         save_user = gr.Button("✅ Save Username")
 
     with gr.Row():
         chatbot = gr.Chatbot(type="messages")
@@ -132,11 +226,14 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
         send = gr.Button("Send")
 
     with gr.Row():
-        clear = gr.Button("🧹 Clear Cache")
-        view = gr.Button("👀 View Cache")
         cache_output = gr.Markdown("")
 
-    def set_user_id(uid):
         return uid, f"✅ Username set as **{uid}**"
 
     def respond(message, history, user_id):
@@ -158,11 +255,11 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
             return "⚠️ Please set username first!"
         return view_user_cache(user_id)
 
-    save_user.click(set_user_id, user_id_input, [user_id_state,
     send.click(respond, [msg, chatbot, user_id_state], [chatbot, msg])
     msg.submit(respond, [msg, chatbot, user_id_state], [chatbot, msg])
     clear.click(clear_cache_ui, [user_id_state, chatbot], [chatbot, cache_output])
     view.click(view_cache_ui, user_id_state, cache_output)
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)

@@ -1,4 +1,5 @@
 import os
+import re
 import gradio as gr
 import redis
 import numpy as np
@@ -7,122 +8,215 @@ from datetime import timedelta
 from openai import AzureOpenAI
 from sentence_transformers import SentenceTransformer
 
+# -----------------------
+# Configuration
+# -----------------------
+REDIS_HOST = "redis-12628.c14.us-east-1-2.ec2.redns.redis-cloud.com"
+REDIS_PORT = 12628
+REDIS_USER = "default"
+REDIS_PASSWORD = os.getenv("REDIS_PASSWORD")
+
+AZURE_API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "").strip()
+AZURE_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT", "").strip()
+AZURE_API_VERSION = "2025-01-01-preview"
+CHAT_DEPLOYMENT = "gpt-4.1"
+
+# Cache TTL (2 days)
+CACHE_TTL = int(timedelta(days=2).total_seconds())
+
+# Matching thresholds
+PRIMARY_THRESHOLD = 0.90   # for same-language matches
+FALLBACK_THRESHOLD = 0.95  # for language-agnostic fallback (very strict)
+
+# -----------------------
+# Clients / Models
+# -----------------------
 redis_client = redis.Redis(
+    host=REDIS_HOST,
+    port=REDIS_PORT,
     decode_responses=True,
+    username=REDIS_USER,
+    password=REDIS_PASSWORD,
 )
 
 client = AzureOpenAI(
+    api_key=AZURE_API_KEY,
+    api_version=AZURE_API_VERSION,
+    azure_endpoint=AZURE_ENDPOINT,
 )
 
+# Embedding model (multilingual, small & strong)
 embedder = SentenceTransformer("intfloat/multilingual-e5-small")
 
+# -----------------------
+# Helpers
+# -----------------------
+def detect_language_tag(text: str):
+    """Return a language tag string (lowercase) or None."""
+    t = text.lower()
+    patterns = [
+        (r'\bjava\b', "java"),
+        (r'\bpython\b', "python"),
+        (r'\b(c\+\+|cpp)\b', "cpp"),
+        (r'\bc#\b|\bcsharp\b', "csharp"),
+        (r'\bjavascript\b|\bjs\b', "javascript"),
+        (r'\b(go|golang)\b', "go"),
+        (r'\bruby\b', "ruby"),
+        (r'\bphp\b', "php"),
+        (r'\bscala\b', "scala"),
+        (r'\br\b', "r"),
+        # C detection is tricky; look for " in c", " c language", or standalone " c "
+        (r'\b in c\b|\bc language\b|\b c \b', "c"),
+    ]
+    for pat, tag in patterns:
+        if re.search(pat, t):
+            return tag
+    return None
 
+def build_embedding_input(text: str, lang_tag: str | None):
+    """Create the text to embed: include language tag prefix if present."""
+    if lang_tag:
+        return f"{lang_tag.upper()}: {text}"
+    return text
+
+def get_embedding(text: str) -> np.ndarray:
+    vec = embedder.encode(text, convert_to_numpy=True)
+    return vec.astype(np.float32)
+
+def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
+    # safe guard against zero vectors
+    n1 = np.linalg.norm(vec1)
+    n2 = np.linalg.norm(vec2)
+    if n1 == 0 or n2 == 0:
+        return 0.0
+    return float(np.dot(vec1, vec2) / (n1 * n2))
+
+# -----------------------
+# Cache functions
+# -----------------------
+def store_cache(user_id: str, user_input: str, output: str):
+    lang = detect_language_tag(user_input)
+    embed_text = build_embedding_input(user_input, lang)
+    vec = get_embedding(embed_text).tolist()
     cache_key = f"cache:{user_id}"
+    store_key = (f"{lang}:" + user_input) if lang else user_input
+    payload = {
+        "orig": user_input,
+        "embedding": vec,
+        "output": output,
+        "lang": lang,
+    }
+    redis_client.hset(cache_key, store_key, json.dumps(payload))
+    redis_client.expire(cache_key, CACHE_TTL)
 
+def search_cache(user_id: str, user_input: str, primary_threshold=PRIMARY_THRESHOLD, fallback_threshold=FALLBACK_THRESHOLD):
+    cache_key = f"cache:{user_id}"
+    entries = redis_client.hgetall(cache_key)
+    if not entries:
+        return None
+
+    # detect language and make embedding with same prefix logic
+    detected_lang = detect_language_tag(user_input)
+    query_embed_text = build_embedding_input(user_input, detected_lang)
+    query_vec = get_embedding(query_embed_text)
+
+    # 1) Try same-language matches (if language detected)
+    best_score = -1.0
+    best_output = None
+    if detected_lang:
+        for _, val in entries.items():
+            entry = json.loads(val)
+            if entry.get("lang") != detected_lang:
+                continue
+            vec = np.array(entry["embedding"], dtype=np.float32)
+            score = cosine_similarity(query_vec, vec)
+            if score > best_score:
+                best_score, best_output = score, entry["output"]
+        if best_score >= primary_threshold:
+            return best_output
+
+    # 2) Try language-agnostic entries (lang == None)
+    best_score = -1.0
+    best_output = None
+    for _, val in entries.items():
         entry = json.loads(val)
+        if entry.get("lang") is not None:
+            continue
         vec = np.array(entry["embedding"], dtype=np.float32)
         score = cosine_similarity(query_vec, vec)
         if score > best_score:
+            best_score, best_output = score, entry["output"]
+    if best_score >= fallback_threshold:
+        return best_output
+
+    # 3) Final fallback: search any language but require very high similarity
+    best_score = -1.0
+    best_output = None
+    for _, val in entries.items():
+        entry = json.loads(val)
+        vec = np.array(entry["embedding"], dtype=np.float32)
+        score = cosine_similarity(query_vec, vec)
+        if score > best_score:
+            best_score, best_output = score, entry["output"]
+    if best_score >= fallback_threshold:
+        return best_output
 
     return None
 
+def clear_user_cache(user_id: str):
+    redis_client.delete(f"cache:{user_id}")
 
+def view_user_cache(user_id: str):
     cache_key = f"cache:{user_id}"
     entries = redis_client.hgetall(cache_key)
     if not entries:
         return "⚠️ No cache stored."
     lines = []
+    for k, v in entries.items():
+        entry = json.loads(v)
+        lang = entry.get("lang") or "general"
+        q = entry.get("orig", k)
+        a = entry.get("output", "")
+        lines.append(f"**Lang:** {lang}\n**Q:** {q}\n**A:** {a}")
     return "\n\n---\n\n".join(lines)
 
+# -----------------------
+# Chat logic
+# -----------------------
+def chat_with_ai(user_id: str, user_input: str):
+    if not user_input or not user_id:
+        return "Please set a username and type something."
 
+    # 1) semantic cache search (language-aware)
     cached = search_cache(user_id, user_input)
     if cached:
         return f"[From Redis] {cached}"
 
+    # 2) fallback to Azure OpenAI
     response = client.chat.completions.create(
         model=CHAT_DEPLOYMENT,
         messages=[{"role": "user", "content": user_input}],
         temperature=0.8,
+        max_tokens=700,
     )
     output = response.choices[0].message.content.strip()
 
+    # store with language-aware embedding
     store_cache(user_id, user_input, output)
     return f"[From OpenAI] {output}"
 
+# -----------------------
 # Gradio UI
+# -----------------------
+with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat (Lang-aware)") as demo:
+    gr.Markdown("# 💬 Azure OpenAI + Redis Cloud (Language-aware Semantic Cache)")
 
     user_id_state = gr.State("")
 
     with gr.Row():
         user_id_input = gr.Textbox(label="Enter Username (only once)", placeholder="Your username")
         save_user = gr.Button("✅ Save Username")
+        user_status = gr.Markdown("")
 
     with gr.Row():
         chatbot = gr.Chatbot(type="messages")
@@ -132,11 +226,14 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
         send = gr.Button("Send")
 
     with gr.Row():
+        clear = gr.Button("🧹 Clear My Cache")
+        view = gr.Button("👀 View My Cache")
         cache_output = gr.Markdown("")
 
+    def set_user_id(uid: str):
+        uid = uid.strip()
+        if not uid:
+            return "", "⚠️ Please enter a non-empty username."
         return uid, f"✅ Username set as **{uid}**"
 
     def respond(message, history, user_id):
@@ -158,11 +255,11 @@ with gr.Blocks(title="Azure OpenAI + Redis Cloud Chat") as demo:
             return "⚠️ Please set username first!"
         return view_user_cache(user_id)
 
+    save_user.click(set_user_id, user_id_input, [user_id_state, user_status])
     send.click(respond, [msg, chatbot, user_id_state], [chatbot, msg])
     msg.submit(respond, [msg, chatbot, user_id_state], [chatbot, msg])
     clear.click(clear_cache_ui, [user_id_state, chatbot], [chatbot, cache_output])
     view.click(view_cache_ui, user_id_state, cache_output)
 
 if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True, pwa=True)
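
For reference, here is a condensed, self-contained sketch of the matching behaviour this commit introduces, so it can be tried outside the Space. It is illustrative only and makes several assumptions: Redis is replaced by an in-memory dict, the language pattern list is shortened to three entries, the three search passes are condensed to two with the same net effect, and store/lookup are hypothetical names standing in for store_cache/search_cache; the 0.90/0.95 defaults mirror PRIMARY_THRESHOLD and FALLBACK_THRESHOLD.

# A condensed, in-memory sketch of the language-aware cache matching.
# Illustrative only: Redis is replaced by a dict, and store()/lookup()
# are hypothetical names, not functions from app.py.
import re
import numpy as np
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("intfloat/multilingual-e5-small")

def detect_language_tag(text: str):
    # Same idea as app.py: word-boundary regexes, first match wins.
    for pat, tag in [(r"\bjava\b", "java"), (r"\bpython\b", "python"), (r"\b(c\+\+|cpp)\b", "cpp")]:
        if re.search(pat, text.lower()):
            return tag
    return None

def embed(text: str, lang: str | None) -> np.ndarray:
    # Prefixing the detected language pushes "same question, different
    # language" prompts apart in embedding space.
    prefixed = f"{lang.upper()}: {text}" if lang else text
    return embedder.encode(prefixed, convert_to_numpy=True).astype(np.float32)

cache: dict[str, dict] = {}  # stand-in for the Redis hash "cache:{user_id}"

def store(question: str, answer: str) -> None:
    lang = detect_language_tag(question)
    key = (f"{lang}:" if lang else "") + question
    cache[key] = {"embedding": embed(question, lang), "output": answer, "lang": lang}

def lookup(question: str, primary: float = 0.90, fallback: float = 0.95):
    lang = detect_language_tag(question)
    q = embed(question, lang)
    # Pass 1: same-language entries at the looser threshold (only if tagged);
    # pass 2: any entry, but only at the stricter fallback threshold.
    passes = ([("same", primary)] if lang else []) + [("any", fallback)]
    for mode, threshold in passes:
        best_score, best_output = -1.0, None
        for entry in cache.values():
            if mode == "same" and entry["lang"] != lang:
                continue
            v = entry["embedding"]
            score = float(np.dot(q, v) / (np.linalg.norm(q) * np.linalg.norm(v)))
            if score > best_score:
                best_score, best_output = score, entry["output"]
        if best_score >= threshold:
            return best_output
    return None

store("reverse a string in python", "Use slicing: s[::-1]")
print(lookup("how to reverse a string in java"))  # expected: None (JAVA-prefixed query misses the PYTHON entry)
print(lookup("reversing a string in python"))     # expected: cache hit via the PYTHON-prefixed embedding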
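Design note: the two thresholds are deliberately asymmetric. A query whose detected language matches a cached entry can hit at 0.90, while any untagged or cross-language fallback must clear 0.95. Combined with the tag prefix in the embedding input, the intent appears to be that a question like "reverse a string in java" no longer reuses a cached Python answer merely because the two sentences read alike.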