Spaces:
Sleeping
Sleeping
Matis Codjia commited on
Commit ·
3ee62c8
1
Parent(s): 3a00bdd
Auto load
Browse files- app.py +131 -459
- cache_manager.py +43 -121
app.py
CHANGED
|
@@ -1,42 +1,40 @@
|
|
| 1 |
"""
|
| 2 |
-
Streamlit RAG Viewer avec Cache Intelligent
|
| 3 |
"""
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from transformers import AutoTokenizer, AutoModel
|
| 9 |
-
from datasets import load_dataset
|
| 10 |
import chromadb
|
| 11 |
from pathlib import Path
|
| 12 |
import json
|
| 13 |
import time
|
| 14 |
import logging
|
| 15 |
import sys
|
|
|
|
|
|
|
|
|
|
| 16 |
# Import des modules custom
|
| 17 |
from cache_manager import CacheManager
|
| 18 |
from deepseek_caller import DeepSeekCaller
|
| 19 |
from stats_logger import StatsLogger
|
| 20 |
from config import DISTANCE_THRESHOLD
|
| 21 |
from utils import load_css
|
| 22 |
-
from huggingface_hub import login, snapshot_download
|
| 23 |
-
import os
|
| 24 |
|
| 25 |
# ==========================================
|
| 26 |
# PAGE CONFIG
|
| 27 |
# ==========================================
|
| 28 |
st.set_page_config(
|
| 29 |
page_title="RAG Feedback System",
|
| 30 |
-
page_icon="",
|
| 31 |
layout="wide",
|
| 32 |
initial_sidebar_state="expanded"
|
| 33 |
)
|
| 34 |
|
|
|
|
| 35 |
DATASET_ID = "matis35/chroma-rag-storage"
|
| 36 |
-
REPO_FOLDER = "chroma_db_storage"
|
| 37 |
-
|
| 38 |
-
# Le dossier local où Streamlit va stocker la DB
|
| 39 |
-
# On se met un niveau au-dessus pour que snapshot_download recrée le dossier "chroma_db_storage" dedans
|
| 40 |
LOCAL_CACHE_DIR = Path("./chroma_cache")
|
| 41 |
|
| 42 |
# ==========================================
|
|
@@ -48,45 +46,38 @@ load_css("assets/style.css")
|
|
| 48 |
# STATE MANAGEMENT
|
| 49 |
# ==========================================
|
| 50 |
if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False
|
| 51 |
-
if 'dataset_loaded' not in st.session_state: st.session_state.dataset_loaded = False
|
| 52 |
if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False
|
| 53 |
if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None
|
| 54 |
if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None
|
| 55 |
if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger()
|
| 56 |
|
| 57 |
# ==========================================
|
| 58 |
-
#
|
| 59 |
# ==========================================
|
| 60 |
logging.basicConfig(
|
| 61 |
level=logging.INFO,
|
| 62 |
format='%(asctime)s | %(levelname)s | %(message)s',
|
| 63 |
datefmt='%H:%M:%S',
|
| 64 |
-
handlers=[
|
| 65 |
-
logging.StreamHandler(sys.stdout)
|
| 66 |
-
]
|
| 67 |
)
|
| 68 |
-
|
| 69 |
logger = logging.getLogger("FFGen_System")
|
|
|
|
|
|
|
| 70 |
hf_token = os.environ.get("HF_TOKEN")
|
|
|
|
|
|
|
| 71 |
|
| 72 |
if hf_token:
|
| 73 |
-
# Se connecte explicitement
|
| 74 |
login(token=hf_token)
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
print("Connected via st.secrets")
|
| 81 |
-
else:
|
| 82 |
-
print("No HF key found")
|
| 83 |
-
except FileNotFoundError:
|
| 84 |
-
print("Local execution without secrets")
|
| 85 |
@st.cache_resource
|
| 86 |
def load_full_model(model_path: str):
|
| 87 |
-
"""
|
| 88 |
-
st.info(f"Loading model from: {model_path}")
|
| 89 |
-
logger.info(f" Loading from: {model_path}...")
|
| 90 |
try:
|
| 91 |
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 92 |
if tokenizer.pad_token is None:
|
|
@@ -97,196 +88,110 @@ def load_full_model(model_path: str):
|
|
| 97 |
trust_remote_code=True,
|
| 98 |
device_map="auto"
|
| 99 |
)
|
| 100 |
-
logger.info(f"Modèle chargé avec succès !")
|
| 101 |
model.eval()
|
| 102 |
return model, tokenizer
|
| 103 |
except Exception as e:
|
| 104 |
-
st.error(f"
|
| 105 |
-
logger.error("Echec du chargement du modèle !")
|
| 106 |
return None, None
|
| 107 |
|
| 108 |
def encode_text(text: str, model, tokenizer):
|
| 109 |
-
"""
|
| 110 |
device = next(model.parameters()).device
|
| 111 |
-
|
| 112 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
|
| 113 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 114 |
-
|
| 115 |
with torch.no_grad():
|
| 116 |
outputs = model(**inputs)
|
| 117 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 118 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
| 119 |
-
|
| 120 |
return embeddings[0].cpu().numpy().tolist()
|
| 121 |
|
| 122 |
-
@st.
|
| 123 |
-
def load_dataset_from_source(source: str, path: str):
|
| 124 |
-
logger.info(f"Source séléctionnée {source}")
|
| 125 |
-
if source == "HuggingFace Hub":
|
| 126 |
-
|
| 127 |
-
dataset = load_dataset(path)
|
| 128 |
-
data = []
|
| 129 |
-
for split in dataset.keys():
|
| 130 |
-
data.extend(dataset[split].to_list())
|
| 131 |
-
return data
|
| 132 |
-
else:
|
| 133 |
-
data = []
|
| 134 |
-
with open(path, 'r') as f:
|
| 135 |
-
for line in f:
|
| 136 |
-
if line.strip():
|
| 137 |
-
data.append(json.loads(line))
|
| 138 |
-
return data
|
| 139 |
-
|
| 140 |
-
# @st.cache_resource # <--- Décommenter si tu es sous Streamlit pour ne le faire qu'une fois !
|
| 141 |
def initialize_chromadb():
|
| 142 |
"""
|
| 143 |
-
|
|
|
|
| 144 |
"""
|
| 145 |
-
|
| 146 |
-
# 1. CHEMIN CIBLE
|
| 147 |
-
# Le chemin final sera : ./chroma_cache/chroma_db_storage
|
| 148 |
final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER
|
| 149 |
|
| 150 |
-
#
|
| 151 |
if not final_db_path.exists():
|
| 152 |
-
print(f"📥
|
| 153 |
try:
|
| 154 |
snapshot_download(
|
| 155 |
repo_id=DATASET_ID,
|
| 156 |
repo_type="dataset",
|
| 157 |
-
local_dir=LOCAL_CACHE_DIR,
|
| 158 |
-
allow_patterns=[f"{REPO_FOLDER}/*"],
|
| 159 |
-
local_dir_use_symlinks=False
|
| 160 |
-
# token=os.environ.get("HF_TOKEN") # Nécessaire si le dataset est PRIVÉ
|
| 161 |
)
|
| 162 |
-
print("
|
| 163 |
except Exception as e:
|
| 164 |
-
|
| 165 |
-
# Fallback : Si on est en local et que le dossier existe déjà ailleurs, on pourrait pointer dessus
|
| 166 |
raise e
|
| 167 |
|
| 168 |
-
#
|
| 169 |
-
|
| 170 |
client = chromadb.PersistentClient(path=str(final_db_path))
|
| 171 |
|
| 172 |
-
#
|
| 173 |
-
# Attention : On ne fait plus de "create_collection" ni de "delete".
|
| 174 |
-
# On récupère juste ce qui existe.
|
| 175 |
try:
|
| 176 |
collection = client.get_collection(name="feedbacks")
|
| 177 |
-
print(f"📊 Collection
|
| 178 |
except Exception as e:
|
| 179 |
-
|
| 180 |
raise e
|
| 181 |
|
| 182 |
return client, collection
|
|
|
|
| 183 |
# ==========================================
|
| 184 |
-
# MAIN
|
| 185 |
# ==========================================
|
| 186 |
|
| 187 |
st.title("FFGEN")
|
| 188 |
st.markdown("### Submit code and get instant feedback")
|
| 189 |
|
| 190 |
-
#
|
| 191 |
-
# SIDEBAR - CONFIGURATION
|
| 192 |
-
# ==========================================
|
| 193 |
-
|
| 194 |
with st.sidebar:
|
| 195 |
-
st.header(" Configuration")
|
| 196 |
-
|
| 197 |
-
# --- MODEL SELECTION ---
|
| 198 |
-
st.subheader("Embedding Model")
|
| 199 |
-
model_path = st.text_input(
|
| 200 |
-
"Model Path (Local or HF)",
|
| 201 |
-
value="matis35/gemmaembedding-fgdor",
|
| 202 |
-
help="Path to embedding model"
|
| 203 |
-
)
|
| 204 |
-
|
| 205 |
-
# --- DATASET SELECTION ---
|
| 206 |
-
st.subheader("Dataset")
|
| 207 |
-
data_source = st.selectbox("Source", ["HuggingFace Hub", "Local JSONL"])
|
| 208 |
-
dataset_path = st.text_input("Dataset Path", value="matis35/SYNT_V4")
|
| 209 |
|
|
|
|
|
|
|
|
|
|
| 210 |
st.divider()
|
| 211 |
-
|
| 212 |
-
#
|
| 213 |
-
st.subheader("Cache
|
| 214 |
-
|
| 215 |
-
# Permettre de modifier le threshold dynamiquement
|
| 216 |
if 'custom_threshold' not in st.session_state:
|
| 217 |
st.session_state.custom_threshold = DISTANCE_THRESHOLD
|
| 218 |
|
| 219 |
custom_threshold = st.slider(
|
| 220 |
-
"
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
value=st.session_state.custom_threshold,
|
| 224 |
-
step=0.05,
|
| 225 |
-
help="Distance < threshold = HIT. Modifier cette valeur change le comportement du cache sans réindexer."
|
| 226 |
)
|
| 227 |
-
|
| 228 |
if custom_threshold != st.session_state.custom_threshold:
|
| 229 |
st.session_state.custom_threshold = custom_threshold
|
| 230 |
-
# Mettre à jour le threshold du cache manager existant si disponible
|
| 231 |
if st.session_state.get('cache_manager'):
|
| 232 |
st.session_state.cache_manager.threshold = custom_threshold
|
| 233 |
-
st.info(f"Threshold updated to {custom_threshold:.2f}")
|
| 234 |
-
|
| 235 |
-
st.caption(f"Current: Distance < {st.session_state.custom_threshold:.2f} = HIT")
|
| 236 |
|
| 237 |
st.divider()
|
| 238 |
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
-
|
| 242 |
-
with col1:
|
| 243 |
-
load_btn = st.button("Load & Index", use_container_width=True)
|
| 244 |
-
with col2:
|
| 245 |
-
use_cached_btn = st.button(" Use Cached", use_container_width=True)
|
| 246 |
|
| 247 |
-
#
|
| 248 |
-
|
| 249 |
-
try:
|
| 250 |
-
client, collection = initialize_chromadb(force_reindex=False)
|
| 251 |
-
count = collection.count()
|
| 252 |
-
if count > 0:
|
| 253 |
-
st.session_state.client = client
|
| 254 |
-
st.session_state.collection = collection
|
| 255 |
-
st.session_state.db_initialized = True
|
| 256 |
-
st.success(f"DB Loaded: {count} docs")
|
| 257 |
-
logger.info(f"Base de données démarrée avec succès: {count} instances")
|
| 258 |
-
if not st.session_state.model_loaded:
|
| 259 |
-
model, tokenizer = load_full_model(model_path)
|
| 260 |
-
if model:
|
| 261 |
-
st.session_state.model = model
|
| 262 |
-
st.session_state.tokenizer = tokenizer
|
| 263 |
-
st.session_state.model_loaded = True
|
| 264 |
-
|
| 265 |
-
# Initialiser cache manager avec threshold dynamique
|
| 266 |
-
encoder_fn = lambda text: encode_text(text, model, tokenizer)
|
| 267 |
-
st.session_state.cache_manager = CacheManager(
|
| 268 |
-
collection,
|
| 269 |
-
encoder_fn,
|
| 270 |
-
threshold=st.session_state.custom_threshold
|
| 271 |
-
)
|
| 272 |
-
|
| 273 |
-
# Initialiser DeepSeek caller
|
| 274 |
-
try:
|
| 275 |
-
st.session_state.deepseek_caller = DeepSeekCaller()
|
| 276 |
-
st.success(" DeepSeek API Ready")
|
| 277 |
-
logger.info("API prête")
|
| 278 |
-
except Exception as e:
|
| 279 |
-
st.warning(f" DeepSeek API unavailable: {e}")
|
| 280 |
-
logger.error(f"API non disponible: {e}")
|
| 281 |
-
else:
|
| 282 |
-
st.warning(" Empty DB. Please Load & Index first.")
|
| 283 |
-
except Exception as e:
|
| 284 |
-
st.error(f"Error: {e}")
|
| 285 |
-
logger.error(f"Problème avec la base de données: {e}")
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
with st.spinner("Loading Model..."):
|
| 290 |
model, tokenizer = load_full_model(model_path)
|
| 291 |
if model:
|
| 292 |
st.session_state.model = model
|
|
@@ -295,340 +200,107 @@ with st.sidebar:
|
|
| 295 |
else:
|
| 296 |
st.stop()
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
try:
|
| 301 |
-
|
| 302 |
-
st.session_state.dataset = data
|
| 303 |
-
st.session_state.dataset_loaded = True
|
| 304 |
-
except Exception as e:
|
| 305 |
-
st.error(f"Dataset Error: {e}")
|
| 306 |
-
logger.error("Problème de chargement du dataset")
|
| 307 |
-
st.stop()
|
| 308 |
-
|
| 309 |
-
if st.session_state.dataset_loaded:
|
| 310 |
-
with st.spinner(f"Indexing {len(data)} items..."):
|
| 311 |
-
client, collection = initialize_chromadb(force_reindex=force_reindex)
|
| 312 |
-
|
| 313 |
-
batch_size = 64
|
| 314 |
-
progress_bar = st.progress(0)
|
| 315 |
-
|
| 316 |
-
for i in range(0, len(data), batch_size):
|
| 317 |
-
batch = data[i:i+batch_size]
|
| 318 |
-
|
| 319 |
-
feedbacks = [item.get("feedback", item.get("generated_feedback", "")) for item in batch]
|
| 320 |
-
codes = [item.get("code") for item in batch]
|
| 321 |
-
|
| 322 |
-
# IMPORTANT: Encode FEEDBACK for bi-encoder retrieval (code→feedback)
|
| 323 |
-
embeddings = [encode_text(fb, model, tokenizer) for fb in feedbacks]
|
| 324 |
-
|
| 325 |
-
# Store code as metadata for later comparison
|
| 326 |
-
metadatas = [{"code": c if c else ""} for c in codes]
|
| 327 |
-
ids = [f"id_{i+j}" for j in range(len(batch))]
|
| 328 |
-
|
| 329 |
-
collection.add(
|
| 330 |
-
embeddings=embeddings,
|
| 331 |
-
documents=feedbacks,
|
| 332 |
-
metadatas=metadatas,
|
| 333 |
-
ids=ids
|
| 334 |
-
)
|
| 335 |
-
progress_bar.progress(min(1.0, (i + batch_size) / len(data)))
|
| 336 |
-
|
| 337 |
st.session_state.client = client
|
| 338 |
st.session_state.collection = collection
|
| 339 |
st.session_state.db_initialized = True
|
| 340 |
-
|
| 341 |
-
#
|
| 342 |
encoder_fn = lambda text: encode_text(text, model, tokenizer)
|
| 343 |
st.session_state.cache_manager = CacheManager(
|
| 344 |
collection,
|
| 345 |
encoder_fn,
|
| 346 |
threshold=st.session_state.custom_threshold
|
| 347 |
)
|
| 348 |
-
|
| 349 |
-
#
|
| 350 |
try:
|
| 351 |
st.session_state.deepseek_caller = DeepSeekCaller()
|
| 352 |
except:
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
st.success("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
-
#
|
| 358 |
-
# MAIN INTERFACE - QUERY
|
| 359 |
-
# ==========================================
|
| 360 |
|
| 361 |
if st.session_state.db_initialized and st.session_state.cache_manager:
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
# Formulaire enrichi
|
| 366 |
with st.form("code_submission"):
|
| 367 |
col1, col2 = st.columns([2, 1])
|
| 368 |
-
|
| 369 |
-
with col1:
|
| 370 |
-
code_input = st.text_area(
|
| 371 |
-
"C Code",
|
| 372 |
-
height=300,
|
| 373 |
-
placeholder="Paste your C code here...",
|
| 374 |
-
help="The code you want feedback on"
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
with col2:
|
| 378 |
-
theme = st.text_input(
|
| 379 |
-
"Exercise Theme",
|
| 380 |
-
placeholder="e.g., Binary Search",
|
| 381 |
-
help="What is this exercise about?"
|
| 382 |
-
)
|
| 383 |
-
|
| 384 |
-
difficulty = st.selectbox(
|
| 385 |
-
"Difficulty Level",
|
| 386 |
-
["beginner", "intermediate", "advanced"]
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
error_category = st.text_input(
|
| 390 |
-
"Error Category (optional)",
|
| 391 |
-
placeholder="e.g., Off-by-one Error",
|
| 392 |
-
help="If you know the type of error"
|
| 393 |
-
)
|
| 394 |
-
|
| 395 |
-
instructions = st.text_area(
|
| 396 |
-
"Exercise Instructions (optional)",
|
| 397 |
-
placeholder="Describe what the function should do...",
|
| 398 |
-
help="Helps generate better feedback on cache miss"
|
| 399 |
-
)
|
| 400 |
-
|
| 401 |
-
col1, col2 = st.columns(2)
|
| 402 |
with col1:
|
| 403 |
-
|
| 404 |
-
"Test Cases Scope (optional)",
|
| 405 |
-
placeholder="e.g., Test with n=0, n=5, n=10",
|
| 406 |
-
help="What tests should pass"
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
with col2:
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
submit_btn = st.form_submit_button(" Search Feedback", use_container_width=True)
|
| 417 |
|
| 418 |
-
# TRAITEMENT DE LA REQUÊTE
|
| 419 |
if submit_btn and code_input:
|
| 420 |
start_time = time.time()
|
| 421 |
-
|
| 422 |
-
# Contexte complet
|
| 423 |
context = {
|
| 424 |
-
"code": code_input,
|
| 425 |
-
"
|
| 426 |
-
"
|
| 427 |
-
"error_category": error_category or "Unknown",
|
| 428 |
-
"instructions": instructions or "No instructions provided",
|
| 429 |
-
"test_cases_scope": [test_scope] if test_scope else [],
|
| 430 |
-
"failed_tests": [failed_tests] if failed_tests else []
|
| 431 |
}
|
| 432 |
|
| 433 |
-
# Query
|
| 434 |
-
with st.spinner(" Searching
|
| 435 |
cache_result = st.session_state.cache_manager.query_cache(code_input, context)
|
|
|
|
|
|
|
| 436 |
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
# CACHE HIT ou PERFECT MATCH
|
| 440 |
if cache_result['status'] in ['hit', 'perfect_match']:
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
st.markdown("### Cache HIT - Feedback from Database")
|
| 450 |
-
|
| 451 |
-
col1, col2, col3 = st.columns(3)
|
| 452 |
-
with col1:
|
| 453 |
-
st.metric("Confidence", f"{cache_result['confidence']:.2f}")
|
| 454 |
-
with col2:
|
| 455 |
-
st.metric("Best Match Distance (code→feedback)", f"{cache_result['similarity_scores'][0]:.4f}")
|
| 456 |
-
with col3:
|
| 457 |
-
st.metric("Response Time", f"{response_time:.0f} ms")
|
| 458 |
-
|
| 459 |
-
# Afficher code similarity si disponible
|
| 460 |
-
if cache_result.get('code_similarity') is not None:
|
| 461 |
-
st.metric("Code Similarity", f"{cache_result['code_similarity']:.4f}",
|
| 462 |
-
help="Similarity between your code and reference code (1.0 = identical)")
|
| 463 |
-
|
| 464 |
-
if cache_result['needs_warning'] and not is_perfect:
|
| 465 |
-
st.warning(" **Note:** Confidence is moderate. Review carefully.")
|
| 466 |
-
|
| 467 |
-
# Afficher les résultats
|
| 468 |
-
for result in cache_result['results']:
|
| 469 |
-
# Calculer distance code_soumis ↔ code_référence
|
| 470 |
-
code_ref = result['code']
|
| 471 |
-
if code_ref and code_ref != 'N/A':
|
| 472 |
-
code_ref_embedding = encode_text(code_ref, st.session_state.model, st.session_state.tokenizer)
|
| 473 |
-
code_submitted_embedding = encode_text(code_input, st.session_state.model, st.session_state.tokenizer)
|
| 474 |
-
|
| 475 |
-
# Cosine similarity
|
| 476 |
-
import numpy as np
|
| 477 |
-
similarity = np.dot(code_ref_embedding, code_submitted_embedding)
|
| 478 |
-
code_distance = 1 - similarity
|
| 479 |
-
else:
|
| 480 |
-
code_distance = None
|
| 481 |
-
|
| 482 |
-
with st.expander(f" Match #{result['rank']} (code→feedback distance: {result['distance']:.4f})"):
|
| 483 |
-
# Métriques côte à côte
|
| 484 |
-
col1, col2 = st.columns(2)
|
| 485 |
-
with col1:
|
| 486 |
-
st.metric("Code → Feedback", f"{result['distance']:.4f}", help="Distance entre votre code et ce feedback (apprentissage bi-encoder)")
|
| 487 |
-
with col2:
|
| 488 |
-
if code_distance is not None:
|
| 489 |
-
st.metric("Code → Code Ref", f"{code_distance:.4f}", help="Distance entre votre code et le code de référence pour ce feedback")
|
| 490 |
-
|
| 491 |
-
st.markdown("**Feedback:**")
|
| 492 |
-
st.write(result['feedback'])
|
| 493 |
-
|
| 494 |
-
st.markdown("**Reference Code (this feedback was given for):**")
|
| 495 |
-
st.code(result['code'], language='c')
|
| 496 |
-
|
| 497 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 498 |
-
|
| 499 |
-
# Log stats
|
| 500 |
-
st.session_state.stats_logger.log_query({
|
| 501 |
-
"query_id": cache_result['query_id'],
|
| 502 |
-
"status": "hit",
|
| 503 |
-
"similarity_score": cache_result['similarity_scores'][0],
|
| 504 |
-
"confidence": cache_result['confidence'],
|
| 505 |
-
"response_time_ms": response_time,
|
| 506 |
-
"theme": theme,
|
| 507 |
-
"error_category": error_category,
|
| 508 |
-
"difficulty": difficulty,
|
| 509 |
-
"deepseek_tokens": 0,
|
| 510 |
-
"cache_size": st.session_state.collection.count()
|
| 511 |
-
})
|
| 512 |
-
|
| 513 |
-
# CACHE MISS
|
| 514 |
-
elif cache_result['status'] == 'miss':
|
| 515 |
-
st.markdown('<div class="miss-card">', unsafe_allow_html=True)
|
| 516 |
-
st.markdown("### Cache MISS - Generating New Feedback")
|
| 517 |
-
|
| 518 |
-
st.info(f" Closest match distance: {cache_result.get('closest_distance', 1.0):.4f} (threshold: {st.session_state.custom_threshold:.2f})")
|
| 519 |
-
|
| 520 |
-
# Afficher les codes les plus proches même en cas de miss
|
| 521 |
-
if cache_result['results']:
|
| 522 |
-
st.markdown("#### Closest matches found (but below threshold):")
|
| 523 |
-
for result in cache_result['results']:
|
| 524 |
-
# Calculer distance code_soumis ↔ code_référence
|
| 525 |
-
code_ref = result['code']
|
| 526 |
-
if code_ref and code_ref != 'N/A':
|
| 527 |
-
code_ref_embedding = encode_text(code_ref, st.session_state.model, st.session_state.tokenizer)
|
| 528 |
-
code_submitted_embedding = encode_text(code_input, st.session_state.model, st.session_state.tokenizer)
|
| 529 |
-
|
| 530 |
-
import numpy as np
|
| 531 |
-
similarity = np.dot(code_ref_embedding, code_submitted_embedding)
|
| 532 |
-
code_distance = 1 - similarity
|
| 533 |
-
else:
|
| 534 |
-
code_distance = None
|
| 535 |
-
|
| 536 |
-
with st.expander(f"Match #{result['rank']} (code→feedback: {result['distance']:.4f})"):
|
| 537 |
-
# Métriques côte à côte
|
| 538 |
-
col1, col2 = st.columns(2)
|
| 539 |
-
with col1:
|
| 540 |
-
st.metric("Code → Feedback", f"{result['distance']:.4f}", help="Distance bi-encoder (apprentissage)")
|
| 541 |
-
with col2:
|
| 542 |
-
if code_distance is not None:
|
| 543 |
-
st.metric("Code → Code Ref", f"{code_distance:.4f}", help="Distance code soumis vs code de référence")
|
| 544 |
-
|
| 545 |
-
st.markdown("**Feedback (given for reference code):**")
|
| 546 |
-
st.write(result['feedback'])
|
| 547 |
-
|
| 548 |
-
st.markdown("**Reference Code:**")
|
| 549 |
-
st.code(result['code'], language='c')
|
| 550 |
-
|
| 551 |
-
st.divider()
|
| 552 |
-
|
| 553 |
-
# Appeler DeepSeek
|
| 554 |
if st.session_state.deepseek_caller:
|
| 555 |
-
with st.spinner(" Generating
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
if
|
| 559 |
-
feedback =
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
st.success(" Feedback Generated!")
|
| 563 |
-
|
| 564 |
-
col1, col2, col3 = st.columns(3)
|
| 565 |
-
with col1:
|
| 566 |
-
st.metric("Tokens Used", tokens_used)
|
| 567 |
-
with col2:
|
| 568 |
-
st.metric("Generation Time", f"{deepseek_result['generation_time_ms']:.0f} ms")
|
| 569 |
-
with col3:
|
| 570 |
-
st.metric("Total Time", f"{response_time + deepseek_result['generation_time_ms']:.0f} ms")
|
| 571 |
-
|
| 572 |
-
st.markdown("**Generated Feedback:**")
|
| 573 |
st.write(feedback)
|
| 574 |
-
|
| 575 |
-
#
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
)
|
| 586 |
-
|
| 587 |
-
if success:
|
| 588 |
-
st.success(" Feedback added to cache for future queries!")
|
| 589 |
-
|
| 590 |
-
# Log cache miss (format dataset)
|
| 591 |
-
miss_data = {
|
| 592 |
-
**context,
|
| 593 |
-
"tags": [tag.strip() for tag in error_category.split(',') if tag.strip()] if error_category else [],
|
| 594 |
-
"feedback": feedback,
|
| 595 |
-
"query_id": cache_result['query_id'],
|
| 596 |
-
"tokens_used": tokens_used
|
| 597 |
-
}
|
| 598 |
-
st.session_state.stats_logger.log_cache_miss(miss_data)
|
| 599 |
-
|
| 600 |
-
# Log stats
|
| 601 |
-
st.session_state.stats_logger.log_query({
|
| 602 |
-
"query_id": cache_result['query_id'],
|
| 603 |
-
"status": "miss",
|
| 604 |
-
"similarity_score": cache_result.get('closest_distance', 1.0),
|
| 605 |
-
"confidence": 1.0, # LLM généré = haute confiance
|
| 606 |
-
"response_time_ms": response_time + deepseek_result['generation_time_ms'],
|
| 607 |
-
"theme": theme,
|
| 608 |
-
"error_category": error_category,
|
| 609 |
-
"difficulty": difficulty,
|
| 610 |
-
"deepseek_tokens": tokens_used,
|
| 611 |
-
"cache_size": st.session_state.collection.count()
|
| 612 |
-
})
|
| 613 |
else:
|
| 614 |
-
st.error(
|
| 615 |
else:
|
| 616 |
-
st.error("
|
| 617 |
-
|
| 618 |
-
st.markdown('</div>', unsafe_allow_html=True)
|
| 619 |
|
| 620 |
else:
|
| 621 |
-
st.info(" Please
|
| 622 |
-
|
| 623 |
-
st.markdown("""
|
| 624 |
-
### How to use:
|
| 625 |
-
1. **Load Model & Dataset** (or use cached DB)
|
| 626 |
-
2. **Fill in the form** with your code and its context
|
| 627 |
-
3. **Submit** to get feedback
|
| 628 |
-
4. **Check the Stats page** to see cache performance
|
| 629 |
-
|
| 630 |
-
### Cache System:
|
| 631 |
-
- **Hit**: Similar code found in database (instant response) Or Relevant feedabck code found in db with code feedback embedder
|
| 632 |
-
- **Miss**: No match found, generates new feedback (slower, uses API tokens)
|
| 633 |
-
- **Distillation**: New feedbacks are automatically added to the cache
|
| 634 |
-
""")
|
|
|
|
| 1 |
"""
|
| 2 |
+
Streamlit RAG Viewer avec Cache Intelligent (Static RAG Mode)
|
| 3 |
"""
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from transformers import AutoTokenizer, AutoModel
|
|
|
|
| 9 |
import chromadb
|
| 10 |
from pathlib import Path
|
| 11 |
import json
|
| 12 |
import time
|
| 13 |
import logging
|
| 14 |
import sys
|
| 15 |
+
import os
|
| 16 |
+
from huggingface_hub import login, snapshot_download
|
| 17 |
+
|
| 18 |
# Import des modules custom
|
| 19 |
from cache_manager import CacheManager
|
| 20 |
from deepseek_caller import DeepSeekCaller
|
| 21 |
from stats_logger import StatsLogger
|
| 22 |
from config import DISTANCE_THRESHOLD
|
| 23 |
from utils import load_css
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ==========================================
|
| 26 |
# PAGE CONFIG
|
| 27 |
# ==========================================
|
| 28 |
st.set_page_config(
|
| 29 |
page_title="RAG Feedback System",
|
| 30 |
+
page_icon="🧠",
|
| 31 |
layout="wide",
|
| 32 |
initial_sidebar_state="expanded"
|
| 33 |
)
|
| 34 |
|
| 35 |
+
# Configuration du Dataset HF contenant la DB Chroma
|
| 36 |
DATASET_ID = "matis35/chroma-rag-storage"
|
| 37 |
+
REPO_FOLDER = "chroma_db_storage"
|
|
|
|
|
|
|
|
|
|
| 38 |
LOCAL_CACHE_DIR = Path("./chroma_cache")
|
| 39 |
|
| 40 |
# ==========================================
|
|
|
|
| 46 |
# STATE MANAGEMENT
|
| 47 |
# ==========================================
|
| 48 |
if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False
|
|
|
|
| 49 |
if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False
|
| 50 |
if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None
|
| 51 |
if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None
|
| 52 |
if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger()
|
| 53 |
|
| 54 |
# ==========================================
|
| 55 |
+
# SETUP & LOGGING
|
| 56 |
# ==========================================
|
| 57 |
logging.basicConfig(
|
| 58 |
level=logging.INFO,
|
| 59 |
format='%(asctime)s | %(levelname)s | %(message)s',
|
| 60 |
datefmt='%H:%M:%S',
|
| 61 |
+
handlers=[logging.StreamHandler(sys.stdout)]
|
|
|
|
|
|
|
| 62 |
)
|
|
|
|
| 63 |
logger = logging.getLogger("FFGen_System")
|
| 64 |
+
|
| 65 |
+
# Authentification HF
|
| 66 |
hf_token = os.environ.get("HF_TOKEN")
|
| 67 |
+
if not hf_token and "HF_TOKEN" in st.secrets:
|
| 68 |
+
hf_token = st.secrets["HF_TOKEN"]
|
| 69 |
|
| 70 |
if hf_token:
|
|
|
|
| 71 |
login(token=hf_token)
|
| 72 |
+
|
| 73 |
+
# ==========================================
|
| 74 |
+
# CORE FUNCTIONS
|
| 75 |
+
# ==========================================
|
| 76 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
@st.cache_resource
|
| 78 |
def load_full_model(model_path: str):
|
| 79 |
+
"""Charge le modèle d'embedding (Hugging Face)"""
|
| 80 |
+
st.info(f"Loading embedding model from: {model_path}...")
|
|
|
|
| 81 |
try:
|
| 82 |
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
| 83 |
if tokenizer.pad_token is None:
|
|
|
|
| 88 |
trust_remote_code=True,
|
| 89 |
device_map="auto"
|
| 90 |
)
|
|
|
|
| 91 |
model.eval()
|
| 92 |
return model, tokenizer
|
| 93 |
except Exception as e:
|
| 94 |
+
st.error(f"Failed to load model: {e}")
|
|
|
|
| 95 |
return None, None
|
| 96 |
|
| 97 |
def encode_text(text: str, model, tokenizer):
|
| 98 |
+
"""Génère l'embedding normalisé"""
|
| 99 |
device = next(model.parameters()).device
|
|
|
|
| 100 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
|
| 101 |
inputs = {k: v.to(device) for k, v in inputs.items()}
|
|
|
|
| 102 |
with torch.no_grad():
|
| 103 |
outputs = model(**inputs)
|
| 104 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 105 |
embeddings = F.normalize(embeddings, p=2, dim=1)
|
|
|
|
| 106 |
return embeddings[0].cpu().numpy().tolist()
|
| 107 |
|
| 108 |
+
@st.cache_resource
def initialize_chromadb():
    """
    Download the pre-built Chroma vector DB from the Hugging Face Hub
    and open a persistent client on it.

    No re-indexing happens here: the dataset repo already contains the
    fully built DB under REPO_FOLDER. Cached by Streamlit so the
    download/connection runs only once per process.

    Returns:
        (client, collection): the chromadb.PersistentClient and its
        'feedbacks' collection.

    Raises:
        Exception: re-raised if the download fails or the expected
        collection is missing from the downloaded DB.
    """
    # snapshot_download recreates REPO_FOLDER inside LOCAL_CACHE_DIR,
    # so the actual DB lives in the nested folder.
    final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER

    # 1. Download only if not already present locally.
    if not final_db_path.exists():
        print(f"📥 Downloading vector DB from {DATASET_ID}...")
        try:
            snapshot_download(
                repo_id=DATASET_ID,
                repo_type="dataset",
                local_dir=LOCAL_CACHE_DIR,
                # Only fetch the DB folder, not the whole dataset repo.
                allow_patterns=[f"{REPO_FOLDER}/*"],
                # Materialize real files rather than symlinks into the HF cache.
                local_dir_use_symlinks=False
            )
            print("✅ Download complete.")
        except Exception as e:
            st.error(f"Failed to download DB: {e}")
            raise e

    # 2. Open a persistent (on-disk) Chroma client on the downloaded DB.
    print(f"🔌 Connecting to ChromaDB at {final_db_path}")
    client = chromadb.PersistentClient(path=str(final_db_path))

    # 3. Sanity check: the pre-built DB must already contain 'feedbacks'.
    try:
        collection = client.get_collection(name="feedbacks")
        print(f"📊 Collection loaded. Documents: {collection.count()}")
    except Exception as e:
        st.error("Collection 'feedbacks' not found in the downloaded DB.")
        raise e

    return client, collection
|
| 145 |
+
|
| 146 |
# ==========================================
|
| 147 |
+
# MAIN INTERFACE
|
| 148 |
# ==========================================
|
| 149 |
|
| 150 |
st.title("FFGEN")
|
| 151 |
st.markdown("### Submit code and get instant feedback")
|
| 152 |
|
| 153 |
+
# --- SIDEBAR ---
|
|
|
|
|
|
|
|
|
|
| 154 |
with st.sidebar:
|
| 155 |
+
st.header("⚙️ System Configuration")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
+
# Model Config
|
| 158 |
+
model_path = st.text_input("Embedding Model", value="matis35/gemmaembedding-fgdor")
|
| 159 |
+
|
| 160 |
st.divider()
|
| 161 |
+
|
| 162 |
+
# Cache Sensitivity
|
| 163 |
+
st.subheader("Cache Sensitivity")
|
|
|
|
|
|
|
| 164 |
if 'custom_threshold' not in st.session_state:
|
| 165 |
st.session_state.custom_threshold = DISTANCE_THRESHOLD
|
| 166 |
|
| 167 |
custom_threshold = st.slider(
|
| 168 |
+
"Similarity Threshold", 0.1, 1.0,
|
| 169 |
+
value=st.session_state.custom_threshold, step=0.05,
|
| 170 |
+
help="Lower = Stricter matching. Higher = More matches."
|
|
|
|
|
|
|
|
|
|
| 171 |
)
|
| 172 |
+
|
| 173 |
if custom_threshold != st.session_state.custom_threshold:
|
| 174 |
st.session_state.custom_threshold = custom_threshold
|
|
|
|
| 175 |
if st.session_state.get('cache_manager'):
|
| 176 |
st.session_state.cache_manager.threshold = custom_threshold
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
st.divider()
|
| 179 |
|
| 180 |
+
# Active Learning Toggle
|
| 181 |
+
enable_learning = st.checkbox(
|
| 182 |
+
"Enable Active Learning",
|
| 183 |
+
value=True,
|
| 184 |
+
help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session."
|
| 185 |
+
)
|
| 186 |
|
| 187 |
+
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
+
# Main Action Button
|
| 190 |
+
start_btn = st.button("🚀 Load System", use_container_width=True, type="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
+
if start_btn:
|
| 193 |
+
# 1. Load Model
|
| 194 |
+
with st.spinner("1/2 Loading Neural Model..."):
|
| 195 |
model, tokenizer = load_full_model(model_path)
|
| 196 |
if model:
|
| 197 |
st.session_state.model = model
|
|
|
|
| 200 |
else:
|
| 201 |
st.stop()
|
| 202 |
|
| 203 |
+
# 2. Download & Connect DB
|
| 204 |
+
with st.spinner("2/2 Downloading & Connecting Vector DB..."):
|
| 205 |
try:
|
| 206 |
+
client, collection = initialize_chromadb() # Appel sans argument !
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
st.session_state.client = client
|
| 208 |
st.session_state.collection = collection
|
| 209 |
st.session_state.db_initialized = True
|
| 210 |
+
|
| 211 |
+
# Init Cache Manager
|
| 212 |
encoder_fn = lambda text: encode_text(text, model, tokenizer)
|
| 213 |
st.session_state.cache_manager = CacheManager(
|
| 214 |
collection,
|
| 215 |
encoder_fn,
|
| 216 |
threshold=st.session_state.custom_threshold
|
| 217 |
)
|
| 218 |
+
|
| 219 |
+
# Init DeepSeek
|
| 220 |
try:
|
| 221 |
st.session_state.deepseek_caller = DeepSeekCaller()
|
| 222 |
except:
|
| 223 |
+
st.warning("DeepSeek key not found, generation disabled.")
|
| 224 |
+
|
| 225 |
+
st.success("System Ready!")
|
| 226 |
+
time.sleep(1) # Petit temps pour voir le succès
|
| 227 |
+
st.rerun()
|
| 228 |
+
|
| 229 |
+
except Exception as e:
|
| 230 |
+
st.error(f"Initialization Error: {e}")
|
| 231 |
|
| 232 |
+
# --- MAIN LOGIC ---
|
|
|
|
|
|
|
| 233 |
|
| 234 |
if st.session_state.db_initialized and st.session_state.cache_manager:
|
| 235 |
+
|
| 236 |
+
# Formulaire de soumission
|
|
|
|
|
|
|
| 237 |
with st.form("code_submission"):
|
| 238 |
col1, col2 = st.columns([2, 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
with col1:
|
| 240 |
+
code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
with col2:
|
| 242 |
+
theme = st.text_input("Theme", placeholder="e.g. Arrays")
|
| 243 |
+
difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"])
|
| 244 |
+
error_cat = st.text_input("Error Type (Optional)")
|
| 245 |
+
|
| 246 |
+
instructions = st.text_area("Instructions", placeholder="Function should return...")
|
| 247 |
+
submit_btn = st.form_submit_button("Search Feedback", use_container_width=True)
|
|
|
|
| 248 |
|
|
|
|
| 249 |
if submit_btn and code_input:
|
| 250 |
start_time = time.time()
|
|
|
|
|
|
|
| 251 |
context = {
|
| 252 |
+
"code": code_input, "theme": theme,
|
| 253 |
+
"difficulty": difficulty, "error_category": error_cat,
|
| 254 |
+
"instructions": instructions
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
}
|
| 256 |
|
| 257 |
+
# 1. Query Cache
|
| 258 |
+
with st.spinner("🔍 Searching knowledge base..."):
|
| 259 |
cache_result = st.session_state.cache_manager.query_cache(code_input, context)
|
| 260 |
+
|
| 261 |
+
elapsed = (time.time() - start_time) * 1000
|
| 262 |
|
| 263 |
+
# CAS 1: HIT ou PERFECT MATCH
|
|
|
|
|
|
|
| 264 |
if cache_result['status'] in ['hit', 'perfect_match']:
|
| 265 |
+
st.success(f"Feedback found in {elapsed:.0f}ms (Confidence: {cache_result['confidence']:.2f})")
|
| 266 |
+
|
| 267 |
+
# Affichage des résultats (Top 1)
|
| 268 |
+
best = cache_result['results'][0]
|
| 269 |
+
st.markdown("### 💡 Retrieved Feedback")
|
| 270 |
+
st.write(best['feedback'])
|
| 271 |
+
|
| 272 |
+
with st.expander("See Reference Code"):
|
| 273 |
+
st.code(best['code'], language='c')
|
| 274 |
+
st.caption(f"Distance: {best['distance']:.4f}")
|
| 275 |
|
| 276 |
+
# CAS 2: MISS -> GENERATION
|
| 277 |
+
else:
|
| 278 |
+
st.warning(f"No similar feedback found (Best distance: {cache_result.get('closest_distance', 1.0):.4f}). Generating new...")
|
| 279 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
if st.session_state.deepseek_caller:
|
| 281 |
+
with st.spinner("🤖 Generating analysis with DeepSeek..."):
|
| 282 |
+
gen_result = st.session_state.deepseek_caller.generate_feedback(context)
|
| 283 |
+
|
| 284 |
+
if 'feedback' in gen_result:
|
| 285 |
+
feedback = gen_result['feedback']
|
| 286 |
+
st.markdown("### 🤖 Generated Feedback")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
st.write(feedback)
|
| 288 |
+
|
| 289 |
+
# LOGIQUE D'APPRENTISSAGE (DISTILLATION)
|
| 290 |
+
if enable_learning:
|
| 291 |
+
with st.spinner("💾 Saving to local session cache..."):
|
| 292 |
+
emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
|
| 293 |
+
st.session_state.cache_manager.add_to_cache(
|
| 294 |
+
code=code_input,
|
| 295 |
+
feedback=feedback,
|
| 296 |
+
metadata=context,
|
| 297 |
+
embedding=emb
|
| 298 |
+
)
|
| 299 |
+
st.toast("Feedback added to cache!", icon="✅")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
else:
|
| 301 |
+
st.error("Generation failed.")
|
| 302 |
else:
|
| 303 |
+
st.error("DeepSeek not configured.")
|
|
|
|
|
|
|
| 304 |
|
| 305 |
else:
|
| 306 |
+
st.info("👈 Please load the system from the sidebar to start.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cache_manager.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
-
Cache Manager - Gère Hit/Miss et distillation
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
-
from typing import Dict, List, Any
|
| 7 |
import uuid
|
| 8 |
from datetime import datetime
|
| 9 |
from config import DISTANCE_THRESHOLD, TOP_K_RESULTS, CONFIDENCE_THRESHOLD_WARNING
|
|
@@ -14,61 +14,31 @@ class CacheManager:
|
|
| 14 |
Args:
|
| 15 |
chroma_collection: Collection ChromaDB
|
| 16 |
encoder_fn: Fonction pour encoder du texte en embedding
|
| 17 |
-
threshold: Custom similarity threshold
|
| 18 |
"""
|
| 19 |
self.collection = chroma_collection
|
| 20 |
self.encoder_fn = encoder_fn
|
| 21 |
self.threshold = threshold if threshold is not None else DISTANCE_THRESHOLD
|
| 22 |
|
| 23 |
def calculate_confidence(self, distances: List[float]) -> float:
|
| 24 |
-
"""
|
| 25 |
-
Calcule un score de confiance basé sur les distances.
|
| 26 |
-
Distance plus faible = confiance plus haute.
|
| 27 |
-
|
| 28 |
-
Returns:
|
| 29 |
-
float entre 0 et 1
|
| 30 |
-
"""
|
| 31 |
if not distances:
|
| 32 |
return 0.0
|
| 33 |
-
|
| 34 |
-
#
|
| 35 |
avg_distance = np.mean(distances)
|
| 36 |
-
|
| 37 |
-
# Convertir distance en confiance (inverse et normalisation)
|
| 38 |
-
# Distance de 0 = confiance 1.0
|
| 39 |
-
# Distance de 0.5 = confiance 0.5
|
| 40 |
-
# Distance de 1.0 = confiance 0.0
|
| 41 |
-
confidence = max(0.0, 1.0 - avg_distance)
|
| 42 |
-
|
| 43 |
-
return round(confidence, 3)
|
| 44 |
|
| 45 |
def query_cache(self, code: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 46 |
"""
|
| 47 |
-
|
| 48 |
-
1. CHECK RAPIDE : Match exact de la chaîne de caractères (via Metadata).
|
| 49 |
-
-> Si trouvé : Retour immédiat (Stop).
|
| 50 |
-
|
| 51 |
-
2. RETRIEVAL : Recherche des 5 vecteurs les plus proches (Bi-Encoder).
|
| 52 |
-
|
| 53 |
-
3. ANALYSE FINE : Sur ces 5 candidats, on vérifie :
|
| 54 |
-
A. Est-ce qu'il y a un "Jumeau Sémantique" ? (Code quasi-identique > 0.95)
|
| 55 |
-
-> Si oui : C'est un HIT forcé (Priorité sur le seuil).
|
| 56 |
-
B. Est-ce que le meilleur candidat est sous le seuil de distance ?
|
| 57 |
-
-> Si oui : C'est un HIT standard.
|
| 58 |
-
|
| 59 |
-
4. DÉCISION : Si ni A ni B -> MISS.
|
| 60 |
"""
|
| 61 |
|
| 62 |
-
#
|
| 63 |
try:
|
| 64 |
-
# On vérifie si la chaîne de caractères brute existe déjà
|
| 65 |
if len(code) < 5000:
|
| 66 |
-
exact_matches = self.collection.get(
|
| 67 |
-
where={"code": code},
|
| 68 |
-
limit=1
|
| 69 |
-
)
|
| 70 |
if exact_matches and len(exact_matches['ids']) > 0:
|
| 71 |
-
print("Cache: MATCH EXACT (String) trouvé !")
|
| 72 |
return {
|
| 73 |
"status": "perfect_match",
|
| 74 |
"results": [{
|
|
@@ -78,22 +48,15 @@ class CacheManager:
|
|
| 78 |
"rank": 1,
|
| 79 |
"metadata": exact_matches['metadatas'][0]
|
| 80 |
}],
|
| 81 |
-
"similarity_scores": [0.0],
|
| 82 |
"confidence": 1.0,
|
| 83 |
-
"needs_deepseek": False,
|
| 84 |
"needs_warning": False,
|
| 85 |
-
"
|
| 86 |
-
"query_embedding": [],
|
| 87 |
-
"perfect_code_match": True
|
| 88 |
}
|
| 89 |
except Exception as e:
|
| 90 |
-
print(f"Warning exact match: {e}")
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
# On a besoin des candidats pour faire les analyses suivantes
|
| 94 |
-
|
| 95 |
query_embedding = self.encoder_fn(code)
|
| 96 |
-
|
| 97 |
query_results = self.collection.query(
|
| 98 |
query_embeddings=[query_embedding],
|
| 99 |
n_results=TOP_K_RESULTS
|
|
@@ -103,129 +66,88 @@ class CacheManager:
|
|
| 103 |
documents = query_results['documents'][0] if query_results['documents'] else []
|
| 104 |
metadatas = query_results['metadatas'][0] if query_results['metadatas'] else []
|
| 105 |
|
| 106 |
-
#
|
| 107 |
-
# On cherche un "Jumeau Sémantique" parmi les résultats retournés
|
| 108 |
-
code_similarity = None
|
| 109 |
perfect_code_match = False
|
|
|
|
| 110 |
|
| 111 |
-
#
|
| 112 |
if metadatas and metadatas[0].get('code'):
|
| 113 |
ref_code = metadatas[0].get('code')
|
| 114 |
if ref_code and ref_code != 'N/A':
|
|
|
|
| 115 |
ref_code_embedding = self.encoder_fn(ref_code)
|
| 116 |
-
# Produit scalaire
|
| 117 |
code_similarity = float(np.dot(query_embedding, ref_code_embedding))
|
| 118 |
-
|
| 119 |
-
# Si > 0.95, c'est le même code écrit différemment (ex: espaces, commentaires)
|
| 120 |
if code_similarity > 0.95:
|
| 121 |
perfect_code_match = True
|
| 122 |
|
| 123 |
-
#
|
| 124 |
-
|
| 125 |
-
# Condition A : Jumeau Sémantique (Le code est quasi identique)
|
| 126 |
-
# Condition B : Proximité Vectorielle Standard (Le sens est proche, sous le seuil)
|
| 127 |
-
|
| 128 |
is_hit = False
|
| 129 |
hit_type = "miss"
|
| 130 |
|
| 131 |
if perfect_code_match:
|
| 132 |
is_hit = True
|
| 133 |
-
hit_type = "perfect_match"
|
| 134 |
elif distances and distances[0] < self.threshold:
|
| 135 |
is_hit = True
|
| 136 |
-
hit_type = "hit"
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
-
# Préparation des résultats formatés (utilisé dans les deux cas)
|
| 141 |
formatted_results = []
|
| 142 |
-
for i, (feedback, metadata,
|
| 143 |
formatted_results.append({
|
| 144 |
"rank": i + 1,
|
| 145 |
"feedback": feedback,
|
| 146 |
"code": metadata.get('code', 'N/A'),
|
| 147 |
-
"distance": round(
|
| 148 |
"metadata": metadata
|
| 149 |
})
|
| 150 |
|
| 151 |
if is_hit:
|
| 152 |
-
# Calcul confiance
|
| 153 |
confidence = self.calculate_confidence(distances)
|
| 154 |
-
if perfect_code_match:
|
| 155 |
-
|
| 156 |
-
|
| 157 |
return {
|
| 158 |
"status": hit_type,
|
| 159 |
"results": formatted_results,
|
| 160 |
-
"
|
| 161 |
-
"confidence": confidence,
|
| 162 |
-
"needs_deepseek": False,
|
| 163 |
-
# Warning uniquement si c'est un hit "mou" (vecteur lointain) ET pas un match de code
|
| 164 |
"needs_warning": False if perfect_code_match else (confidence < CONFIDENCE_THRESHOLD_WARNING),
|
| 165 |
-
"
|
| 166 |
-
"query_id": str(uuid.uuid4()),
|
| 167 |
-
"code_similarity": round(code_similarity, 4) if code_similarity is not None else None,
|
| 168 |
-
"perfect_code_match": perfect_code_match
|
| 169 |
}
|
| 170 |
-
|
| 171 |
else:
|
| 172 |
-
# MISS
|
| 173 |
return {
|
| 174 |
"status": "miss",
|
| 175 |
-
"results": formatted_results,
|
| 176 |
-
"similarity_scores": [round(d, 4) for d in distances] if distances else [],
|
| 177 |
"confidence": 0.0,
|
| 178 |
-
"needs_deepseek": True,
|
| 179 |
"needs_warning": False,
|
| 180 |
-
"query_embedding": query_embedding,
|
| 181 |
-
"query_id": str(uuid.uuid4()),
|
| 182 |
"closest_distance": round(distances[0], 4) if distances else 1.0
|
| 183 |
}
|
|
|
|
| 184 |
def add_to_cache(self, code: str, feedback: str, metadata: Dict[str, Any], embedding: List[float]) -> bool:
|
| 185 |
"""
|
| 186 |
-
Ajoute
|
| 187 |
-
|
| 188 |
-
Args:
|
| 189 |
-
code: Code source
|
| 190 |
-
feedback: Feedback généré
|
| 191 |
-
metadata: Métadonnées complètes (theme, difficulty, etc.)
|
| 192 |
-
embedding: Embedding du feedback
|
| 193 |
-
|
| 194 |
-
Returns:
|
| 195 |
-
bool: True si succès
|
| 196 |
"""
|
| 197 |
try:
|
| 198 |
-
doc_id = f"
|
| 199 |
-
|
| 200 |
-
#
|
| 201 |
-
|
| 202 |
-
"code": code,
|
| 203 |
"timestamp": datetime.now().isoformat(),
|
| 204 |
-
"source": "
|
|
|
|
|
|
|
| 205 |
}
|
| 206 |
|
| 207 |
self.collection.add(
|
| 208 |
embeddings=[embedding],
|
| 209 |
documents=[feedback],
|
| 210 |
-
metadatas=[
|
| 211 |
ids=[doc_id]
|
| 212 |
)
|
| 213 |
-
|
| 214 |
return True
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
-
print(f"Error adding to cache: {e}")
|
| 218 |
-
return False
|
| 219 |
-
|
| 220 |
-
def get_cache_stats(self) -> Dict[str, Any]:
|
| 221 |
-
"""Retourne des stats sur le cache"""
|
| 222 |
-
try:
|
| 223 |
-
total_docs = self.collection.count()
|
| 224 |
-
|
| 225 |
-
return {
|
| 226 |
-
"total_documents": total_docs,
|
| 227 |
-
"similarity_threshold": SIMILARITY_THRESHOLD,
|
| 228 |
-
"top_k": TOP_K_RESULTS
|
| 229 |
-
}
|
| 230 |
-
except Exception as e:
|
| 231 |
-
return {"error": str(e)}
|
|
|
|
| 1 |
"""
|
| 2 |
+
Cache Manager - Gère Hit/Miss et distillation locale
|
| 3 |
"""
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
+
from typing import Dict, List, Any
|
| 7 |
import uuid
|
| 8 |
from datetime import datetime
|
| 9 |
from config import DISTANCE_THRESHOLD, TOP_K_RESULTS, CONFIDENCE_THRESHOLD_WARNING
|
|
|
|
| 14 |
Args:
|
| 15 |
chroma_collection: Collection ChromaDB
|
| 16 |
encoder_fn: Fonction pour encoder du texte en embedding
|
| 17 |
+
threshold: Custom similarity threshold
|
| 18 |
"""
|
| 19 |
self.collection = chroma_collection
|
| 20 |
self.encoder_fn = encoder_fn
|
| 21 |
self.threshold = threshold if threshold is not None else DISTANCE_THRESHOLD
|
| 22 |
|
| 23 |
def calculate_confidence(self, distances: List[float]) -> float:
    """Convert Chroma cosine distances into a confidence score in [0, 1].

    With hnsw:space="cosine", distance = 1 - similarity, so confidence
    is the average similarity of the candidates:
    confidence = 1 - mean(distance), clamped at 0 for distances > 1.

    Args:
        distances: Distances returned by a Chroma query (may be empty).

    Returns:
        A plain Python float in [0.0, 1.0]; 0.0 when there are no candidates.
    """
    if not distances:
        return 0.0
    # np.mean returns a numpy scalar; cast so callers always get the
    # builtin float the annotation promises (round(), JSON, comparisons).
    avg_distance = float(np.mean(distances))
    return max(0.0, 1.0 - avg_distance)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
def query_cache(self, code: str, context: Dict[str, Any]) -> Dict[str, Any]:
|
| 33 |
"""
|
| 34 |
+
Recherche dans le cache (Pipeline Hybride: Exact Match -> Vector Search -> Code Comparison)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
"""
|
| 36 |
|
| 37 |
+
# 1. CHECK RAPIDE (String Exact Match)
|
| 38 |
try:
|
|
|
|
| 39 |
if len(code) < 5000:
|
| 40 |
+
exact_matches = self.collection.get(where={"code": code}, limit=1)
|
|
|
|
|
|
|
|
|
|
| 41 |
if exact_matches and len(exact_matches['ids']) > 0:
|
|
|
|
| 42 |
return {
|
| 43 |
"status": "perfect_match",
|
| 44 |
"results": [{
|
|
|
|
| 48 |
"rank": 1,
|
| 49 |
"metadata": exact_matches['metadatas'][0]
|
| 50 |
}],
|
|
|
|
| 51 |
"confidence": 1.0,
|
|
|
|
| 52 |
"needs_warning": False,
|
| 53 |
+
"closest_distance": 0.0
|
|
|
|
|
|
|
| 54 |
}
|
| 55 |
except Exception as e:
|
| 56 |
+
print(f"Warning exact match check: {e}")
|
| 57 |
|
| 58 |
+
# 2. RETRIEVAL (Vectorielle)
|
|
|
|
|
|
|
| 59 |
query_embedding = self.encoder_fn(code)
|
|
|
|
| 60 |
query_results = self.collection.query(
|
| 61 |
query_embeddings=[query_embedding],
|
| 62 |
n_results=TOP_K_RESULTS
|
|
|
|
| 66 |
documents = query_results['documents'][0] if query_results['documents'] else []
|
| 67 |
metadatas = query_results['metadatas'][0] if query_results['metadatas'] else []
|
| 68 |
|
| 69 |
+
# 3. ANALYSE (Similarity Check)
|
|
|
|
|
|
|
| 70 |
perfect_code_match = False
|
| 71 |
+
code_similarity = 0.0
|
| 72 |
|
| 73 |
+
# Vérification sémantique du code sur le meilleur candidat
|
| 74 |
if metadatas and metadatas[0].get('code'):
|
| 75 |
ref_code = metadatas[0].get('code')
|
| 76 |
if ref_code and ref_code != 'N/A':
|
| 77 |
+
# On encode le code de référence pour comparer avec le code d'entrée
|
| 78 |
ref_code_embedding = self.encoder_fn(ref_code)
|
| 79 |
+
# Produit scalaire (approximatif si vecteurs normalisés)
|
| 80 |
code_similarity = float(np.dot(query_embedding, ref_code_embedding))
|
|
|
|
|
|
|
| 81 |
if code_similarity > 0.95:
|
| 82 |
perfect_code_match = True
|
| 83 |
|
| 84 |
+
# 4. DÉCISION
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
is_hit = False
|
| 86 |
hit_type = "miss"
|
| 87 |
|
| 88 |
if perfect_code_match:
|
| 89 |
is_hit = True
|
| 90 |
+
hit_type = "perfect_match"
|
| 91 |
elif distances and distances[0] < self.threshold:
|
| 92 |
is_hit = True
|
| 93 |
+
hit_type = "hit"
|
| 94 |
|
| 95 |
+
# Formatage des résultats
|
|
|
|
|
|
|
| 96 |
formatted_results = []
|
| 97 |
+
for i, (feedback, metadata, dist) in enumerate(zip(documents, metadatas, distances)):
|
| 98 |
formatted_results.append({
|
| 99 |
"rank": i + 1,
|
| 100 |
"feedback": feedback,
|
| 101 |
"code": metadata.get('code', 'N/A'),
|
| 102 |
+
"distance": round(dist, 4),
|
| 103 |
"metadata": metadata
|
| 104 |
})
|
| 105 |
|
| 106 |
if is_hit:
|
|
|
|
| 107 |
confidence = self.calculate_confidence(distances)
|
| 108 |
+
if perfect_code_match: confidence = 1.0
|
| 109 |
+
|
|
|
|
| 110 |
return {
|
| 111 |
"status": hit_type,
|
| 112 |
"results": formatted_results,
|
| 113 |
+
"confidence": round(confidence, 3),
|
|
|
|
|
|
|
|
|
|
| 114 |
"needs_warning": False if perfect_code_match else (confidence < CONFIDENCE_THRESHOLD_WARNING),
|
| 115 |
+
"closest_distance": round(distances[0], 4)
|
|
|
|
|
|
|
|
|
|
| 116 |
}
|
|
|
|
| 117 |
else:
|
|
|
|
| 118 |
return {
|
| 119 |
"status": "miss",
|
| 120 |
+
"results": formatted_results,
|
|
|
|
| 121 |
"confidence": 0.0,
|
|
|
|
| 122 |
"needs_warning": False,
|
|
|
|
|
|
|
| 123 |
"closest_distance": round(distances[0], 4) if distances else 1.0
|
| 124 |
}
|
| 125 |
+
|
| 126 |
def add_to_cache(self, code: str, feedback: str, metadata: Dict[str, Any], embedding: List[float]) -> bool:
    """Persist a freshly generated feedback into the session-local cache.

    Metadata is sanitized to scalar strings first, because Chroma does
    not accept lists or None values in metadata fields.

    Returns:
        True on success, False if the insertion failed (error is logged).
    """
    try:
        entry_id = "learned_" + uuid.uuid4().hex[:8]

        # Keep only scalar metadata and cap the stored code size.
        clean_meta = {
            "code": code[:10000],
            "timestamp": datetime.now().isoformat(),
            "source": "active_learning",
        }
        for field in ("theme", "difficulty"):
            clean_meta[field] = str(metadata.get(field, ""))

        self.collection.add(
            embeddings=[embedding],
            documents=[feedback],
            metadatas=[clean_meta],
            ids=[entry_id],
        )
        print(f"✅ Learned new feedback: {entry_id}")
        return True

    except Exception as e:
        print(f"❌ Error adding to cache: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|