Spaces:
Sleeping
Sleeping
| """ | |
| Cache Manager - Gère Hit/Miss et distillation locale | |
| """ | |
| import numpy as np | |
| from typing import Dict, List, Any | |
| import uuid | |
| from datetime import datetime | |
| from config import DISTANCE_THRESHOLD, TOP_K_RESULTS, CONFIDENCE_THRESHOLD_WARNING | |
class CacheManager:
    """Feedback cache backed by a ChromaDB collection.

    Lookups run a hybrid pipeline (exact code match -> vector retrieval ->
    code-to-code comparison of the best candidate), and new feedback can be
    appended at runtime with ``add_to_cache`` (active learning).
    """

    # Inputs of this many characters or more skip the exact-match metadata filter.
    EXACT_MATCH_MAX_LEN = 5000
    # Code-to-code cosine distance below which the best candidate is treated as
    # "the same code, written differently" (i.e. cosine similarity > 0.9).
    CODE_HIT_DISTANCE = 0.1
    # Confidence floor reported for a code hit.
    CODE_HIT_MIN_CONFIDENCE = 0.95
    # Stored code is truncated to this many characters in the record metadata.
    MAX_STORED_CODE_LEN = 10000

    def __init__(self, chroma_collection, encoder_fn, threshold=None):
        """
        Args:
            chroma_collection: ChromaDB collection holding feedback documents;
                record metadata may carry the source ``code``.
            encoder_fn: Callable that turns a text string into an embedding vector.
            threshold: Maximum Chroma distance of the best candidate for a
                ``feedback_hit``. Defaults to ``DISTANCE_THRESHOLD`` from config.
        """
        self.collection = chroma_collection
        self.encoder_fn = encoder_fn
        self.threshold = threshold if threshold is not None else DISTANCE_THRESHOLD

    def calculate_confidence(self, distances: List[float]) -> float:
        """Convert Chroma distances into a confidence score in [0, 1].

        Assumes the collection uses ``hnsw:space="cosine"``, where
        ``distance = 1 - similarity``. Confidence is ``1 - mean(distances)``,
        clamped to [0, 1]. An empty list yields 0.0.
        """
        if not distances:
            return 0.0
        avg_distance = float(np.mean(distances))
        return min(1.0, max(0.0, 1.0 - avg_distance))

    @staticmethod
    def _cosine_distance(vec_a, vec_b) -> float:
        """Return ``1 - cosine_similarity(vec_a, vec_b)``, floored at 0.0.

        Both vectors are normalised explicitly, so the result is a true cosine
        distance even when ``encoder_fn`` does not return unit vectors (a raw
        dot product would then overshoot 1 and fake a perfect match).
        A zero-length vector yields the worst-case default of 1.0.
        """
        a = np.asarray(vec_a, dtype=float).ravel()
        b = np.asarray(vec_b, dtype=float).ravel()
        norm_product = float(np.linalg.norm(a) * np.linalg.norm(b))
        if norm_product == 0.0:
            return 1.0
        similarity = float(np.dot(a, b)) / norm_product
        return max(0.0, 1.0 - similarity)

    def _exact_match_lookup(self, code: str):
        """Level 1: return a ``perfect_match`` result if ``code`` is stored verbatim.

        Returns None when there is no exact match, when ``code`` has
        EXACT_MATCH_MAX_LEN characters or more, or when the lookup fails.
        Failures are printed and swallowed on purpose so the caller can fall
        through to vector search.
        """
        try:
            if len(code) >= self.EXACT_MATCH_MAX_LEN:
                return None
            exact_matches = self.collection.get(where={"code": code}, limit=1)
            if not exact_matches or not exact_matches['ids']:
                return None
            feedback = exact_matches['documents'][0]
            metadata = exact_matches['metadatas'][0]
        except Exception as e:
            print(f"Warning exact match check: {e}")
            return None
        return {
            "status": "perfect_match",
            "results": [{
                "feedback": feedback,
                "code": code,
                "distance": 0.0,
                "rank": 1,
                "metadata": metadata,
            }],
            "confidence": 1.0,
            "needs_warning": False,
            "closest_distance": 0.0,
        }

    def query_cache(self, code: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Look up cached feedback for ``code``.

        Hybrid pipeline:
          1. Exact match on the stored ``code`` metadata -> ``perfect_match``.
          2. Vector search of the collection with the embedded input code.
          3. Code-to-code comparison between the input and the best candidate's
             stored code -> ``code_hit`` when nearly identical.
        Otherwise, if the best candidate's distance is below ``self.threshold``
        the status is ``feedback_hit``, else ``miss``.

        Args:
            code: Source code submitted by the user.
            context: Not used by this method.

        Returns:
            Dict with keys ``status``, ``results`` (ranked candidates),
            ``confidence``, ``needs_warning`` and ``closest_distance``.
        """
        # --- Level 1: fast exact string match ---
        perfect = self._exact_match_lookup(code)
        if perfect is not None:
            return perfect

        # --- Level 2: vector retrieval ---
        query_embedding = self.encoder_fn(code)
        # Candidates are ranked by proximity between the input-code embedding
        # and the stored (feedback) embeddings.
        query_results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=TOP_K_RESULTS,
        )
        distances = query_results['distances'][0] if query_results['distances'] else []
        documents = query_results['documents'][0] if query_results['documents'] else []
        raw_metadatas = query_results['metadatas'][0] if query_results['metadatas'] else []
        # Chroma returns None for records stored without metadata.
        metadatas = [meta or {} for meta in raw_metadatas]

        # --- Level 3: fine-grained comparison (code vs code) ---
        # Is the best candidate's stored code semantically close to the input?
        is_code_hit = False
        ref_code = metadatas[0].get('code') if metadatas else None
        if ref_code and ref_code != 'N/A':
            code_distance = self._cosine_distance(query_embedding, self.encoder_fn(ref_code))
            # Very strict threshold: "same code", possibly written differently.
            is_code_hit = code_distance < self.CODE_HIT_DISTANCE

        # --- Final decision ---
        if is_code_hit:
            # Priority 1: code is near-identical in embedding space.
            hit_type = "code_hit"
        elif distances and distances[0] < self.threshold:
            # Priority 2: relevant feedback (standard RAG) under the threshold.
            hit_type = "feedback_hit"
        else:
            hit_type = "miss"

        formatted_results = [
            {
                "rank": rank,
                "feedback": feedback,
                "code": metadata.get('code', 'N/A'),
                "distance": round(dist, 4),
                "metadata": metadata,
            }
            for rank, (feedback, metadata, dist) in enumerate(
                zip(documents, metadatas, distances), start=1
            )
        ]
        closest_distance = round(distances[0], 4) if distances else 1.0

        if hit_type == "miss":
            return {
                "status": "miss",
                "results": formatted_results,
                "confidence": 0.0,
                "needs_warning": False,
                "closest_distance": closest_distance,
            }

        confidence = self.calculate_confidence(distances)
        if hit_type == "code_hit":
            # A near-identical code match is trusted regardless of feedback distances.
            confidence = max(confidence, self.CODE_HIT_MIN_CONFIDENCE)
            needs_warning = False
        else:
            needs_warning = confidence < CONFIDENCE_THRESHOLD_WARNING
        return {
            "status": hit_type,
            "results": formatted_results,
            "confidence": round(confidence, 3),
            "needs_warning": needs_warning,
            "closest_distance": closest_distance,
        }

    def add_to_cache(self, code: str, feedback: str, metadata: Dict[str, Any], embedding: List[float]) -> bool:
        """Store a new feedback record in the collection (active learning).

        Args:
            code: Source code the feedback refers to. It is stored, truncated to
                MAX_STORED_CODE_LEN characters, in the record metadata, where the
                exact-match and code-to-code checks read it.
            feedback: Feedback text, stored as the record's document.
            metadata: Extra attributes. Only ``theme`` and ``difficulty`` are
                kept, coerced to strings. ``None`` is treated as empty.
            embedding: Vector stored for the record and used by vector search.

        Returns:
            True on success, False if the insert failed (the error is printed).
        """
        metadata = metadata or {}
        try:
            # Full UUID hex: an 8-char (32-bit) prefix makes ID collisions
            # probable after tens of thousands of learned entries.
            doc_id = f"learned_{uuid.uuid4().hex}"
            safe_metadata = {
                "code": code[:self.MAX_STORED_CODE_LEN],
                "timestamp": datetime.now().isoformat(),
                "source": "active_learning",
                "theme": str(metadata.get("theme", "")),
                "difficulty": str(metadata.get("difficulty", "")),
            }
            self.collection.add(
                embeddings=[embedding],
                documents=[feedback],
                metadatas=[safe_metadata],
                ids=[doc_id],
            )
            print(f"✅ Learned new feedback: {doc_id}")
            return True
        except Exception as e:
            print(f"❌ Error adding to cache: {e}")
            return False