Spaces:
Running
Running
File size: 6,765 Bytes
eec9162 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from cache.cache import SemanticCache, CacheEntry
# Labelled query pairs used to probe the similarity threshold.
# The pairs are grouped by expected outcome; TEST_PAIRS flattens them into
# (query_to_cache, incoming_query, should_hit, description) tuples.

# Same meaning, different wording -> the cache is expected to HIT.
_PARAPHRASE_PAIRS = [
    ("What are NASA's latest space missions?",
     "Recent space exploration news from NASA",
     "Paraphrase: space missions"),
    ("How does encryption protect data?",
     "What is the role of cryptography in security?",
     "Paraphrase: encryption/cryptography"),
    ("Best hockey teams this season",
     "Top performing NHL teams right now",
     "Paraphrase: hockey teams"),
    ("Is there a God? Religious perspectives",
     "Arguments for and against the existence of God",
     "Paraphrase: religion/God"),
    ("How to fix Windows file system errors?",
     "Troubleshooting Windows disk problems",
     "Paraphrase: Windows issues"),
]

# Unrelated topics -> the cache is expected to MISS.
_UNRELATED_PAIRS = [
    ("NASA space shuttle launch schedule",
     "How does gun control legislation work?",
     "Different topics: space vs politics"),
    ("Best baseball teams of all time",
     "Jewish history in the Middle East",
     "Different topics: sports vs history"),
    ("How to configure SCSI drives on Mac?",
     "What does the Bible say about forgiveness?",
     "Different topics: hardware vs religion"),
]

TEST_PAIRS = [
    (cached, incoming, True, note)
    for cached, incoming, note in _PARAPHRASE_PAIRS
] + [
    (cached, incoming, False, note)
    for cached, incoming, note in _UNRELATED_PAIRS
]
# ─────────────────────────────────────────────
# RUN ANALYSIS
# ─────────────────────────────────────────────
def _sweep_thresholds(similarities, expected_hits, thresholds):
    """Compute hit rate and precision at each candidate threshold.

    A pair counts as a hit when its similarity is >= the threshold.

    Args:
        similarities: One similarity score per query pair.
        expected_hits: Parallel truthy/falsy flags; truthy means the pair
            is supposed to hit.
        thresholds: Iterable of thresholds to evaluate.

    Returns:
        ``(hit_rates, precisions)``: two lists with one float per threshold.
        Hit rate is hits / number of pairs; precision is correct hits / hits,
        reported as 0.0 when nothing hits.
    """
    pair_count = len(similarities)
    hit_rates = []
    precisions = []
    for thresh in thresholds:
        # Expected-outcome flags of every pair that clears this threshold.
        flagged = [
            want for sim, want in zip(similarities, expected_hits)
            if sim >= thresh
        ]
        hits = len(flagged)
        correct_hits = sum(1 for want in flagged if want)
        # Guard both ratios so an empty pair list or zero hits cannot divide by zero.
        hit_rates.append(hits / pair_count if pair_count else 0.0)
        precisions.append(correct_hits / hits if hits > 0 else 0.0)
    return hit_rates, precisions


def _save_threshold_plot(thresholds, hit_rates, precisions, output_path,
                         recommended=0.85):
    """Plot hit rate and precision against threshold and save the figure to *output_path*."""
    # (series, line format, y-axis label, title) for the left and right panels.
    panels = [
        (hit_rates, 'bo-', 'Hit Rate',
         'Cache Hit Rate vs Threshold\n(higher = more aggressive caching)'),
        (precisions, 'gs-', 'Precision',
         'Cache Precision vs Threshold\n(higher = fewer wrong answers)'),
    ]
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (series, line_format, y_label, title) in zip(axes, panels):
        ax.plot(thresholds, series, line_format, linewidth=2, markersize=6)
        ax.axvline(x=recommended, color='red', linestyle='--',
                   label=f'Recommended ({recommended:.2f})')
        ax.set_xlabel('Similarity Threshold')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 1.05)
    plt.suptitle(
        'Threshold Analysis: Hit Rate vs Precision Tradeoff\n'
        'Low threshold = more hits but wrong answers. '
        'High threshold = accurate but useless cache.',
        fontsize=10
    )
    plt.tight_layout()
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    plt.savefig(output_path, dpi=150)
    plt.close(fig)


def run_threshold_analysis() -> None:
    """Measure how the similarity threshold affects cache hits on TEST_PAIRS.

    Embeds both queries of every pair in TEST_PAIRS, prints the similarity of
    each pair, sweeps thresholds from 0.50 to 0.95 in steps of 0.05 reporting
    hit rate and precision, saves a two-panel plot to
    ``./cache/threshold_analysis.png`` and prints a short interpretation.
    """
    # NOTE(review): the status glyphs in the messages below were restored from
    # mis-encoded text - confirm them against the original file.
    plot_path = './cache/threshold_analysis.png'

    print("=" * 55)
    print(" Threshold Analysis")
    print("=" * 55)

    # Load embedding model
    print("\n🤖 Loading embedding model...")
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Embed all queries in one batch: cached queries first, incoming second.
    print("⚙️ Embedding test queries...")
    pair_count = len(TEST_PAIRS)
    all_queries = [p[0] for p in TEST_PAIRS] + [p[1] for p in TEST_PAIRS]
    all_embeddings = model.encode(
        all_queries,
        normalize_embeddings=True,
        convert_to_numpy=True
    )
    cached_embeddings = all_embeddings[:pair_count]
    incoming_embeddings = all_embeddings[pair_count:]

    # Embeddings are requested normalized, so the dot product is used as the
    # cosine similarity of each (cached, incoming) pair.
    similarities = [
        float(np.dot(cached, incoming))
        for cached, incoming in zip(cached_embeddings, incoming_embeddings)
    ]

    print("\n📊 Actual cosine similarities between query pairs:")
    print("-" * 55)
    for pair, sim in zip(TEST_PAIRS, similarities):
        label = "✅ SHOULD HIT" if pair[2] else "❌ SHOULD MISS"
        print(f" {label} sim={sim:.3f} {pair[3]}")

    # Test across thresholds. linspace pins both endpoints and the count,
    # which a float-step arange does not guarantee.
    thresholds = np.linspace(0.50, 0.95, 10)
    print(f"\n🔍 Testing thresholds from {thresholds[0]:.2f} "
          f"to {thresholds[-1]:.2f}...")
    print("-" * 55)
    hit_rates, precisions = _sweep_thresholds(
        similarities, [pair[2] for pair in TEST_PAIRS], thresholds
    )
    for thresh, hit_rate, precision in zip(thresholds, hit_rates, precisions):
        print(f" threshold={thresh:.2f} → "
              f"hit_rate={hit_rate:.2f} "
              f"precision={precision:.2f}")

    # ─────────────────────────────────────────
    # PLOT
    # ─────────────────────────────────────────
    _save_threshold_plot(thresholds, hit_rates, precisions, plot_path)
    print(f"\n✅ Plot saved to {plot_path}")

    # ─────────────────────────────────────────
    # INTERPRETATION
    # ─────────────────────────────────────────
    print("\n\n💡 What Each Threshold Reveals:")
    print("-" * 55)
    print(" 0.50 → Catches almost everything, including wrong matches")
    print(" High hit rate but low precision = wrong answers returned")
    print(" 0.70 → Better precision but still some false positives")
    print(" 0.85 → Sweet spot: paraphrases match, distinct queries don't")
    print(" This is our recommended default")
    print(" 0.95 → Very strict: only near-identical queries match")
    print(" Cache barely ever triggers = defeats the purpose")
    print(" 0.99 → Almost never hits = cache is useless at this level")
    print("\n🎉 Threshold analysis complete!")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_threshold_analysis()