PraneshJs commited on
Commit
cc33aed
·
verified ·
1 Parent(s): 5d8a695

imporved correct embedding

Browse files
Files changed (1) hide show
  1. app.py +14 -3
app.py CHANGED
@@ -39,7 +39,7 @@ def get_embedding(text):
39
  def cosine_similarity(vec1, vec2):
40
  return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
41
 
42
- def search_cache(user_id, user_input, threshold=0.8):
43
  query_vec = get_embedding(user_input)
44
  best_key, best_score, best_val = None, -1, None
45
  cache_key = f"cache:{user_id}"
@@ -58,11 +58,22 @@ def search_cache(user_id, user_input, threshold=0.8):
58
  def store_cache(user_id, user_input, output):
59
  vec = get_embedding(user_input).tolist()
60
  cache_key = f"cache:{user_id}"
61
- redis_client.hset(cache_key, user_input, json.dumps({
 
 
 
 
 
 
 
 
 
 
 
 
62
  "embedding": vec,
63
  "output": output
64
  }))
65
- # ⏳ Reset TTL every time user stores something
66
  redis_client.expire(cache_key, CACHE_TTL)
67
 
68
  def clear_user_cache(user_id):
 
39
  def cosine_similarity(vec1, vec2):
40
  return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
41
 
42
+ def search_cache(user_id, user_input, threshold=0.9): # stricter threshold
43
  query_vec = get_embedding(user_input)
44
  best_key, best_score, best_val = None, -1, None
45
  cache_key = f"cache:{user_id}"
 
58
  def store_cache(user_id, user_input, output):
59
  vec = get_embedding(user_input).tolist()
60
  cache_key = f"cache:{user_id}"
61
+
62
+ # Store extra context: include language keyword if present
63
+ context_input = user_input.lower()
64
+ if "java" in context_input:
65
+ context_input = "JAVA: " + context_input
66
+ elif "python" in context_input:
67
+ context_input = "PYTHON: " + context_input
68
+ elif "c++" in context_input or "cpp" in context_input:
69
+ context_input = "CPP: " + context_input
70
+ elif "c " in context_input:
71
+ context_input = "C: " + context_input
72
+
73
+ redis_client.hset(cache_key, context_input, json.dumps({
74
  "embedding": vec,
75
  "output": output
76
  }))
 
77
  redis_client.expire(cache_key, CACHE_TTL)
78
 
79
  def clear_user_cache(user_id):