Spaces:
Running
Running
| import sys | |
| import os | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from sentence_transformers import SentenceTransformer | |
| from cache.cache import SemanticCache, CacheEntry | |
# ─────────────────────────────────────────────
# TEST QUERY PAIRS
# ─────────────────────────────────────────────
# Every entry is a 4-tuple:
#     (query_to_cache, incoming_query, should_hit, description)
# should_hit=True  -> the incoming query paraphrases the cached one, so a
#                     semantic cache ought to serve it (expected HIT).
# should_hit=False -> the two queries are about unrelated topics, so the
#                     cache must fall through (expected MISS).
TEST_PAIRS = [
    # Expected cache HITS: same intent, different wording.
    ("What are NASA's latest space missions?",
     "Recent space exploration news from NASA",
     True, "Paraphrase: space missions"),
    ("How does encryption protect data?",
     "What is the role of cryptography in security?",
     True, "Paraphrase: encryption/cryptography"),
    ("Best hockey teams this season",
     "Top performing NHL teams right now",
     True, "Paraphrase: hockey teams"),
    ("Is there a God? Religious perspectives",
     "Arguments for and against the existence of God",
     True, "Paraphrase: religion/God"),
    ("How to fix Windows file system errors?",
     "Troubleshooting Windows disk problems",
     True, "Paraphrase: Windows issues"),
    # Expected cache MISSES: unrelated subjects.
    ("NASA space shuttle launch schedule",
     "How does gun control legislation work?",
     False, "Different topics: space vs politics"),
    ("Best baseball teams of all time",
     "Jewish history in the Middle East",
     False, "Different topics: sports vs history"),
    ("How to configure SCSI drives on Mac?",
     "What does the Bible say about forgiveness?",
     False, "Different topics: hardware vs religion"),
]
# ─────────────────────────────────────────────
# RUN ANALYSIS
# ─────────────────────────────────────────────
def run_threshold_analysis():
    """Sweep cache similarity thresholds over ``TEST_PAIRS`` and plot the results.

    Every (cached, incoming) query pair is embedded with
    ``sentence-transformers/all-MiniLM-L6-v2`` and scored by cosine
    similarity. For each candidate threshold a pair counts as a cache hit
    when ``similarity >= threshold``, and three metrics are reported:

    * hit rate  -- fraction of all pairs that would be served from cache;
    * precision -- fraction of those hits whose ``should_hit`` flag is True
      (undefined when nothing hits: stored as NaN, printed as ``n/a`` and
      drawn as a gap in the plot);
    * recall    -- fraction of ``should_hit`` pairs that were actually hit.

    Side effects: prints a report to stdout, creates ``./cache`` if it does
    not exist and writes ``./cache/threshold_analysis.png``. Returns None.
    """
    recommended = 0.85  # threshold highlighted in both plots and the notes
    output_path = './cache/threshold_analysis.png'
    # Explicit values instead of np.arange(0.50, 1.00, 0.05): the float
    # steps drift (e.g. 0.8500000000000003), which would reject a similarity
    # of exactly 0.85, and arange stops at 0.95 although 0.99 is discussed
    # in the interpretation below.
    thresholds = np.array(
        [0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 0.99]
    )

    def fmt(value):
        """Format a metric with two decimals, or 'n/a' when it is NaN."""
        return "n/a" if np.isnan(value) else f"{value:.2f}"

    print("=" * 55)
    print(" Threshold Analysis")
    print("=" * 55)

    n_pairs = len(TEST_PAIRS)
    if n_pairs == 0:
        # Nothing to embed; also avoids dividing by zero in the sweep.
        print("\nNo test pairs defined; nothing to analyse.")
        return

    # Load embedding model
    print("\n🤖 Loading embedding model...")
    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Embed both sides in one batch: rows [0, n) are the cached queries and
    # rows [n, 2n) the matching incoming queries.
    print("⚙️ Embedding test queries...")
    all_queries = [p[0] for p in TEST_PAIRS] + [p[1] for p in TEST_PAIRS]
    all_embeddings = model.encode(
        all_queries,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    cached_embeddings = all_embeddings[:n_pairs]
    incoming_embeddings = all_embeddings[n_pairs:]

    # Embeddings are requested normalised, so the row-wise dot product is
    # the cosine similarity of each (cached, incoming) pair.
    similarities = np.sum(cached_embeddings * incoming_embeddings, axis=1)
    should_hit = np.array([p[2] for p in TEST_PAIRS], dtype=bool)
    total_should_hit = int(should_hit.sum())

    print("\n📊 Actual cosine similarities between query pairs:")
    print("-" * 55)
    for pair, sim in zip(TEST_PAIRS, similarities):
        label = "✅ SHOULD HIT" if pair[2] else "❌ SHOULD MISS"
        print(f" {label} sim={sim:.3f} {pair[3]}")

    # ── Threshold sweep ──
    hit_rates = []
    precisions = []
    recalls = []
    print(f"\n📈 Testing thresholds from {thresholds[0]:.2f} "
          f"to {thresholds[-1]:.2f}...")
    print("-" * 55)
    for thresh in thresholds:
        predicted_hit = similarities >= thresh
        hits = int(predicted_hit.sum())
        correct_hits = int((predicted_hit & should_hit).sum())

        hit_rate = hits / n_pairs
        # Precision is undefined with zero hits; NaN (not 0.0) so strict
        # thresholds are not misreported as "inaccurate".
        precision = correct_hits / hits if hits else float("nan")
        recall = (correct_hits / total_should_hit
                  if total_should_hit else float("nan"))

        hit_rates.append(hit_rate)
        precisions.append(precision)
        recalls.append(recall)
        print(f" threshold={thresh:.2f} → "
              f"hit_rate={hit_rate:.2f} "
              f"precision={fmt(precision)} "
              f"recall={fmt(recall)}")

    # ── Plot ──
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        # Left: hit rate and recall vs threshold
        ax1.plot(thresholds, hit_rates, 'bo-', linewidth=2, markersize=6,
                 label='Hit rate')
        ax1.plot(thresholds, recalls, 'm^--', linewidth=1.5, markersize=5,
                 label='Recall')
        ax1.axvline(x=recommended, color='red', linestyle='--',
                    label=f'Recommended ({recommended:.2f})')
        ax1.set_xlabel('Similarity Threshold')
        ax1.set_ylabel('Hit Rate / Recall')
        ax1.set_title('Cache Hit Rate vs Threshold\n'
                      '(higher = more aggressive caching)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim(0, 1.05)

        # Right: precision vs threshold (NaN points leave a gap)
        ax2.plot(thresholds, precisions, 'gs-', linewidth=2, markersize=6,
                 label='Precision')
        ax2.axvline(x=recommended, color='red', linestyle='--',
                    label=f'Recommended ({recommended:.2f})')
        ax2.set_xlabel('Similarity Threshold')
        ax2.set_ylabel('Precision')
        ax2.set_title('Cache Precision vs Threshold\n'
                      '(higher = fewer wrong answers)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1.05)

        fig.suptitle(
            'Threshold Analysis: Hit Rate vs Precision Tradeoff\n'
            'Low threshold = more hits but wrong answers. '
            'High threshold = accurate but useless cache.',
            fontsize=10,
        )
        fig.tight_layout()
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        fig.savefig(output_path, dpi=150)
    finally:
        # Release the figure even when layout or saving fails.
        plt.close(fig)
    print(f"\n✅ Plot saved to {output_path}")

    # ── Interpretation ──
    print("\n\n💡 What Each Threshold Reveals:")
    print("-" * 55)
    print(" 0.50 → Catches almost everything, including wrong matches")
    print(" High hit rate but low precision = wrong answers returned")
    print(" 0.70 → Better precision but still some false positives")
    print(" 0.85 → Sweet spot: paraphrases match, distinct queries don't")
    print(" This is our recommended default")
    print(" 0.95 → Very strict: only near-identical queries match")
    print(" Cache barely ever triggers = defeats the purpose")
    print(" 0.99 → Almost never hits = cache is useless at this level")
    print("\n🎉 Threshold analysis complete!")


if __name__ == "__main__":
    run_threshold_analysis()