Matis Codjia committed on
Commit
3ee62c8
·
1 Parent(s): 3a00bdd

Auto load

Browse files
Files changed (2) hide show
  1. app.py +131 -459
  2. cache_manager.py +43 -121
app.py CHANGED
@@ -1,42 +1,40 @@
1
  """
2
- Streamlit RAG Viewer avec Cache Intelligent
3
  """
4
 
5
  import streamlit as st
6
  import torch
7
  import torch.nn.functional as F
8
  from transformers import AutoTokenizer, AutoModel
9
- from datasets import load_dataset
10
  import chromadb
11
  from pathlib import Path
12
  import json
13
  import time
14
  import logging
15
  import sys
 
 
 
16
  # Import des modules custom
17
  from cache_manager import CacheManager
18
  from deepseek_caller import DeepSeekCaller
19
  from stats_logger import StatsLogger
20
  from config import DISTANCE_THRESHOLD
21
  from utils import load_css
22
- from huggingface_hub import login, snapshot_download
23
- import os
24
 
25
  # ==========================================
26
  # PAGE CONFIG
27
  # ==========================================
28
  st.set_page_config(
29
  page_title="RAG Feedback System",
30
- page_icon="",
31
  layout="wide",
32
  initial_sidebar_state="expanded"
33
  )
34
 
 
35
  DATASET_ID = "matis35/chroma-rag-storage"
36
- REPO_FOLDER = "chroma_db_storage" # Le nom du dossier DANS le repo HF
37
-
38
- # Le dossier local où Streamlit va stocker la DB
39
- # On se met un niveau au-dessus pour que snapshot_download recrée le dossier "chroma_db_storage" dedans
40
  LOCAL_CACHE_DIR = Path("./chroma_cache")
41
 
42
  # ==========================================
@@ -48,45 +46,38 @@ load_css("assets/style.css")
48
  # STATE MANAGEMENT
49
  # ==========================================
50
  if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False
51
- if 'dataset_loaded' not in st.session_state: st.session_state.dataset_loaded = False
52
  if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False
53
  if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None
54
  if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None
55
  if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger()
56
 
57
  # ==========================================
58
- # HELPER FUNCTIONS
59
  # ==========================================
60
  logging.basicConfig(
61
  level=logging.INFO,
62
  format='%(asctime)s | %(levelname)s | %(message)s',
63
  datefmt='%H:%M:%S',
64
- handlers=[
65
- logging.StreamHandler(sys.stdout)
66
- ]
67
  )
68
-
69
  logger = logging.getLogger("FFGen_System")
 
 
70
  hf_token = os.environ.get("HF_TOKEN")
 
 
71
 
72
  if hf_token:
73
- # Se connecte explicitement
74
  login(token=hf_token)
75
- print("Successfully connected to huggingface")
76
- else:
77
- try:
78
- if "HF_TOKEN" in st.secrets:
79
- login(token=st.secrets["HF_TOKEN"])
80
- print("Connected via st.secrets")
81
- else:
82
- print("No HF key found")
83
- except FileNotFoundError:
84
- print("Local execution without secrets")
85
  @st.cache_resource
86
  def load_full_model(model_path: str):
87
- """Load standard HuggingFace model."""
88
- st.info(f"Loading model from: {model_path}")
89
- logger.info(f" Loading from: {model_path}...")
90
  try:
91
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
92
  if tokenizer.pad_token is None:
@@ -97,196 +88,110 @@ def load_full_model(model_path: str):
97
  trust_remote_code=True,
98
  device_map="auto"
99
  )
100
- logger.info(f"Modèle chargé avec succès !")
101
  model.eval()
102
  return model, tokenizer
103
  except Exception as e:
104
- st.error(f"Erreur de chargement: {e}")
105
- logger.error("Echec du chargement du modèle !")
106
  return None, None
107
 
108
  def encode_text(text: str, model, tokenizer):
109
- """Encode text to embedding."""
110
  device = next(model.parameters()).device
111
-
112
  inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
113
  inputs = {k: v.to(device) for k, v in inputs.items()}
114
-
115
  with torch.no_grad():
116
  outputs = model(**inputs)
117
  embeddings = outputs.last_hidden_state.mean(dim=1)
118
  embeddings = F.normalize(embeddings, p=2, dim=1)
119
-
120
  return embeddings[0].cpu().numpy().tolist()
121
 
122
- @st.cache_data
123
- def load_dataset_from_source(source: str, path: str):
124
- logger.info(f"Source séléctionnée {source}")
125
- if source == "HuggingFace Hub":
126
-
127
- dataset = load_dataset(path)
128
- data = []
129
- for split in dataset.keys():
130
- data.extend(dataset[split].to_list())
131
- return data
132
- else:
133
- data = []
134
- with open(path, 'r') as f:
135
- for line in f:
136
- if line.strip():
137
- data.append(json.loads(line))
138
- return data
139
-
140
- # @st.cache_resource # <--- Décommenter si tu es sous Streamlit pour ne le faire qu'une fois !
141
  def initialize_chromadb():
142
  """
143
- Mode Static RAG : Télécharge la DB depuis Hugging Face et se connecte en lecture seule.
 
144
  """
145
-
146
- # 1. CHEMIN CIBLE
147
- # Le chemin final sera : ./chroma_cache/chroma_db_storage
148
  final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER
149
 
150
- # 2. TÉLÉCHARGEMENT (Si pas déjà présent)
151
  if not final_db_path.exists():
152
- print(f"📥 Téléchargement de la base depuis {DATASET_ID}...")
153
  try:
154
  snapshot_download(
155
  repo_id=DATASET_ID,
156
  repo_type="dataset",
157
- local_dir=LOCAL_CACHE_DIR, # On télécharge DANS le cache
158
- allow_patterns=[f"{REPO_FOLDER}/*"], # On ne prend que le dossier DB
159
- local_dir_use_symlinks=False,
160
- # token=os.environ.get("HF_TOKEN") # Nécessaire si le dataset est PRIVÉ
161
  )
162
- print("Téléchargement terminé.")
163
  except Exception as e:
164
- print(f" Erreur de téléchargement : {e}")
165
- # Fallback : Si on est en local et que le dossier existe déjà ailleurs, on pourrait pointer dessus
166
  raise e
167
 
168
- # 3. CONNEXION CHROMA
169
- # On pointe vers le dossier contenant le fichier sqlite3
170
  client = chromadb.PersistentClient(path=str(final_db_path))
171
 
172
- # 4. RÉCUPÉRATION DE LA COLLECTION
173
- # Attention : On ne fait plus de "create_collection" ni de "delete".
174
- # On récupère juste ce qui existe.
175
  try:
176
  collection = client.get_collection(name="feedbacks")
177
- print(f"📊 Collection chargée. {collection.count()} documents disponibles.")
178
  except Exception as e:
179
- print(f" Erreur : La collection 'feedbacks' n'existe pas dans la base téléchargée.")
180
  raise e
181
 
182
  return client, collection
 
183
  # ==========================================
184
- # MAIN APP
185
  # ==========================================
186
 
187
  st.title("FFGEN")
188
  st.markdown("### Submit code and get instant feedback")
189
 
190
- # ==========================================
191
- # SIDEBAR - CONFIGURATION
192
- # ==========================================
193
-
194
  with st.sidebar:
195
- st.header(" Configuration")
196
-
197
- # --- MODEL SELECTION ---
198
- st.subheader("Embedding Model")
199
- model_path = st.text_input(
200
- "Model Path (Local or HF)",
201
- value="matis35/gemmaembedding-fgdor",
202
- help="Path to embedding model"
203
- )
204
-
205
- # --- DATASET SELECTION ---
206
- st.subheader("Dataset")
207
- data_source = st.selectbox("Source", ["HuggingFace Hub", "Local JSONL"])
208
- dataset_path = st.text_input("Dataset Path", value="matis35/SYNT_V4")
209
 
 
 
 
210
  st.divider()
211
-
212
- # --- CACHE SETTINGS ---
213
- st.subheader("Cache Settings")
214
-
215
- # Permettre de modifier le threshold dynamiquement
216
  if 'custom_threshold' not in st.session_state:
217
  st.session_state.custom_threshold = DISTANCE_THRESHOLD
218
 
219
  custom_threshold = st.slider(
220
- "Semantic distance threshold",
221
- min_value=0.1,
222
- max_value=1.0,
223
- value=st.session_state.custom_threshold,
224
- step=0.05,
225
- help="Distance < threshold = HIT. Modifier cette valeur change le comportement du cache sans réindexer."
226
  )
227
-
228
  if custom_threshold != st.session_state.custom_threshold:
229
  st.session_state.custom_threshold = custom_threshold
230
- # Mettre à jour le threshold du cache manager existant si disponible
231
  if st.session_state.get('cache_manager'):
232
  st.session_state.cache_manager.threshold = custom_threshold
233
- st.info(f"Threshold updated to {custom_threshold:.2f}")
234
-
235
- st.caption(f"Current: Distance < {st.session_state.custom_threshold:.2f} = HIT")
236
 
237
  st.divider()
238
 
239
- force_reindex = st.checkbox("Force Re-index", value=False)
 
 
 
 
 
240
 
241
- col1, col2 = st.columns(2)
242
- with col1:
243
- load_btn = st.button("Load & Index", use_container_width=True)
244
- with col2:
245
- use_cached_btn = st.button(" Use Cached", use_container_width=True)
246
 
247
- # --- LOAD CACHED DB ---
248
- if use_cached_btn:
249
- try:
250
- client, collection = initialize_chromadb(force_reindex=False)
251
- count = collection.count()
252
- if count > 0:
253
- st.session_state.client = client
254
- st.session_state.collection = collection
255
- st.session_state.db_initialized = True
256
- st.success(f"DB Loaded: {count} docs")
257
- logger.info(f"Base de données démarrée avec succès: {count} instances")
258
- if not st.session_state.model_loaded:
259
- model, tokenizer = load_full_model(model_path)
260
- if model:
261
- st.session_state.model = model
262
- st.session_state.tokenizer = tokenizer
263
- st.session_state.model_loaded = True
264
-
265
- # Initialiser cache manager avec threshold dynamique
266
- encoder_fn = lambda text: encode_text(text, model, tokenizer)
267
- st.session_state.cache_manager = CacheManager(
268
- collection,
269
- encoder_fn,
270
- threshold=st.session_state.custom_threshold
271
- )
272
-
273
- # Initialiser DeepSeek caller
274
- try:
275
- st.session_state.deepseek_caller = DeepSeekCaller()
276
- st.success(" DeepSeek API Ready")
277
- logger.info("API prête")
278
- except Exception as e:
279
- st.warning(f" DeepSeek API unavailable: {e}")
280
- logger.error(f"API non disponible: {e}")
281
- else:
282
- st.warning(" Empty DB. Please Load & Index first.")
283
- except Exception as e:
284
- st.error(f"Error: {e}")
285
- logger.error(f"Problème avec la base de données: {e}")
286
 
287
- # --- LOAD AND INDEX ---
288
- if load_btn:
289
- with st.spinner("Loading Model..."):
290
  model, tokenizer = load_full_model(model_path)
291
  if model:
292
  st.session_state.model = model
@@ -295,340 +200,107 @@ with st.sidebar:
295
  else:
296
  st.stop()
297
 
298
- with st.spinner("Loading Dataset..."):
299
- logger.info("Chargement du dataset")
300
  try:
301
- data = load_dataset_from_source(data_source, dataset_path)
302
- st.session_state.dataset = data
303
- st.session_state.dataset_loaded = True
304
- except Exception as e:
305
- st.error(f"Dataset Error: {e}")
306
- logger.error("Problème de chargement du dataset")
307
- st.stop()
308
-
309
- if st.session_state.dataset_loaded:
310
- with st.spinner(f"Indexing {len(data)} items..."):
311
- client, collection = initialize_chromadb(force_reindex=force_reindex)
312
-
313
- batch_size = 64
314
- progress_bar = st.progress(0)
315
-
316
- for i in range(0, len(data), batch_size):
317
- batch = data[i:i+batch_size]
318
-
319
- feedbacks = [item.get("feedback", item.get("generated_feedback", "")) for item in batch]
320
- codes = [item.get("code") for item in batch]
321
-
322
- # IMPORTANT: Encode FEEDBACK for bi-encoder retrieval (code→feedback)
323
- embeddings = [encode_text(fb, model, tokenizer) for fb in feedbacks]
324
-
325
- # Store code as metadata for later comparison
326
- metadatas = [{"code": c if c else ""} for c in codes]
327
- ids = [f"id_{i+j}" for j in range(len(batch))]
328
-
329
- collection.add(
330
- embeddings=embeddings,
331
- documents=feedbacks,
332
- metadatas=metadatas,
333
- ids=ids
334
- )
335
- progress_bar.progress(min(1.0, (i + batch_size) / len(data)))
336
-
337
  st.session_state.client = client
338
  st.session_state.collection = collection
339
  st.session_state.db_initialized = True
340
-
341
- # Initialiser cache manager avec threshold dynamique
342
  encoder_fn = lambda text: encode_text(text, model, tokenizer)
343
  st.session_state.cache_manager = CacheManager(
344
  collection,
345
  encoder_fn,
346
  threshold=st.session_state.custom_threshold
347
  )
348
-
349
- # Initialiser DeepSeek
350
  try:
351
  st.session_state.deepseek_caller = DeepSeekCaller()
352
  except:
353
- pass
354
-
355
- st.success(" Indexing Complete!")
 
 
 
 
 
356
 
357
- # ==========================================
358
- # MAIN INTERFACE - QUERY
359
- # ==========================================
360
 
361
  if st.session_state.db_initialized and st.session_state.cache_manager:
362
-
363
- st.header(" Submit Your Code")
364
-
365
- # Formulaire enrichi
366
  with st.form("code_submission"):
367
  col1, col2 = st.columns([2, 1])
368
-
369
- with col1:
370
- code_input = st.text_area(
371
- "C Code",
372
- height=300,
373
- placeholder="Paste your C code here...",
374
- help="The code you want feedback on"
375
- )
376
-
377
- with col2:
378
- theme = st.text_input(
379
- "Exercise Theme",
380
- placeholder="e.g., Binary Search",
381
- help="What is this exercise about?"
382
- )
383
-
384
- difficulty = st.selectbox(
385
- "Difficulty Level",
386
- ["beginner", "intermediate", "advanced"]
387
- )
388
-
389
- error_category = st.text_input(
390
- "Error Category (optional)",
391
- placeholder="e.g., Off-by-one Error",
392
- help="If you know the type of error"
393
- )
394
-
395
- instructions = st.text_area(
396
- "Exercise Instructions (optional)",
397
- placeholder="Describe what the function should do...",
398
- help="Helps generate better feedback on cache miss"
399
- )
400
-
401
- col1, col2 = st.columns(2)
402
  with col1:
403
- test_scope = st.text_input(
404
- "Test Cases Scope (optional)",
405
- placeholder="e.g., Test with n=0, n=5, n=10",
406
- help="What tests should pass"
407
- )
408
-
409
  with col2:
410
- failed_tests = st.text_input(
411
- "Failed Tests (optional)",
412
- placeholder="e.g., Test n=0 returns wrong value",
413
- help="Which tests are failing"
414
- )
415
-
416
- submit_btn = st.form_submit_button(" Search Feedback", use_container_width=True)
417
 
418
- # TRAITEMENT DE LA REQUÊTE
419
  if submit_btn and code_input:
420
  start_time = time.time()
421
-
422
- # Contexte complet
423
  context = {
424
- "code": code_input,
425
- "theme": theme or "N/A",
426
- "difficulty": difficulty,
427
- "error_category": error_category or "Unknown",
428
- "instructions": instructions or "No instructions provided",
429
- "test_cases_scope": [test_scope] if test_scope else [],
430
- "failed_tests": [failed_tests] if failed_tests else []
431
  }
432
 
433
- # Query cache
434
- with st.spinner(" Searching cache..."):
435
  cache_result = st.session_state.cache_manager.query_cache(code_input, context)
 
 
436
 
437
- response_time = (time.time() - start_time) * 1000 # ms
438
-
439
- # CACHE HIT ou PERFECT MATCH
440
  if cache_result['status'] in ['hit', 'perfect_match']:
441
- is_perfect = cache_result['status'] == 'perfect_match'
442
-
443
- st.markdown('<div class="hit-card">', unsafe_allow_html=True)
 
 
 
 
 
 
 
444
 
445
- if is_perfect:
446
- st.markdown("### PERFECT CODE MATCH - Exact Feedback Found")
447
- st.success("The submitted code is identical (similarity > 95%) to a code in the database. This feedback is 100% accurate.")
448
- else:
449
- st.markdown("### Cache HIT - Feedback from Database")
450
-
451
- col1, col2, col3 = st.columns(3)
452
- with col1:
453
- st.metric("Confidence", f"{cache_result['confidence']:.2f}")
454
- with col2:
455
- st.metric("Best Match Distance (code→feedback)", f"{cache_result['similarity_scores'][0]:.4f}")
456
- with col3:
457
- st.metric("Response Time", f"{response_time:.0f} ms")
458
-
459
- # Afficher code similarity si disponible
460
- if cache_result.get('code_similarity') is not None:
461
- st.metric("Code Similarity", f"{cache_result['code_similarity']:.4f}",
462
- help="Similarity between your code and reference code (1.0 = identical)")
463
-
464
- if cache_result['needs_warning'] and not is_perfect:
465
- st.warning(" **Note:** Confidence is moderate. Review carefully.")
466
-
467
- # Afficher les résultats
468
- for result in cache_result['results']:
469
- # Calculer distance code_soumis ↔ code_référence
470
- code_ref = result['code']
471
- if code_ref and code_ref != 'N/A':
472
- code_ref_embedding = encode_text(code_ref, st.session_state.model, st.session_state.tokenizer)
473
- code_submitted_embedding = encode_text(code_input, st.session_state.model, st.session_state.tokenizer)
474
-
475
- # Cosine similarity
476
- import numpy as np
477
- similarity = np.dot(code_ref_embedding, code_submitted_embedding)
478
- code_distance = 1 - similarity
479
- else:
480
- code_distance = None
481
-
482
- with st.expander(f" Match #{result['rank']} (code→feedback distance: {result['distance']:.4f})"):
483
- # Métriques côte à côte
484
- col1, col2 = st.columns(2)
485
- with col1:
486
- st.metric("Code → Feedback", f"{result['distance']:.4f}", help="Distance entre votre code et ce feedback (apprentissage bi-encoder)")
487
- with col2:
488
- if code_distance is not None:
489
- st.metric("Code → Code Ref", f"{code_distance:.4f}", help="Distance entre votre code et le code de référence pour ce feedback")
490
-
491
- st.markdown("**Feedback:**")
492
- st.write(result['feedback'])
493
-
494
- st.markdown("**Reference Code (this feedback was given for):**")
495
- st.code(result['code'], language='c')
496
-
497
- st.markdown('</div>', unsafe_allow_html=True)
498
-
499
- # Log stats
500
- st.session_state.stats_logger.log_query({
501
- "query_id": cache_result['query_id'],
502
- "status": "hit",
503
- "similarity_score": cache_result['similarity_scores'][0],
504
- "confidence": cache_result['confidence'],
505
- "response_time_ms": response_time,
506
- "theme": theme,
507
- "error_category": error_category,
508
- "difficulty": difficulty,
509
- "deepseek_tokens": 0,
510
- "cache_size": st.session_state.collection.count()
511
- })
512
-
513
- # CACHE MISS
514
- elif cache_result['status'] == 'miss':
515
- st.markdown('<div class="miss-card">', unsafe_allow_html=True)
516
- st.markdown("### Cache MISS - Generating New Feedback")
517
-
518
- st.info(f" Closest match distance: {cache_result.get('closest_distance', 1.0):.4f} (threshold: {st.session_state.custom_threshold:.2f})")
519
-
520
- # Afficher les codes les plus proches même en cas de miss
521
- if cache_result['results']:
522
- st.markdown("#### Closest matches found (but below threshold):")
523
- for result in cache_result['results']:
524
- # Calculer distance code_soumis ↔ code_référence
525
- code_ref = result['code']
526
- if code_ref and code_ref != 'N/A':
527
- code_ref_embedding = encode_text(code_ref, st.session_state.model, st.session_state.tokenizer)
528
- code_submitted_embedding = encode_text(code_input, st.session_state.model, st.session_state.tokenizer)
529
-
530
- import numpy as np
531
- similarity = np.dot(code_ref_embedding, code_submitted_embedding)
532
- code_distance = 1 - similarity
533
- else:
534
- code_distance = None
535
-
536
- with st.expander(f"Match #{result['rank']} (code→feedback: {result['distance']:.4f})"):
537
- # Métriques côte à côte
538
- col1, col2 = st.columns(2)
539
- with col1:
540
- st.metric("Code → Feedback", f"{result['distance']:.4f}", help="Distance bi-encoder (apprentissage)")
541
- with col2:
542
- if code_distance is not None:
543
- st.metric("Code → Code Ref", f"{code_distance:.4f}", help="Distance code soumis vs code de référence")
544
-
545
- st.markdown("**Feedback (given for reference code):**")
546
- st.write(result['feedback'])
547
-
548
- st.markdown("**Reference Code:**")
549
- st.code(result['code'], language='c')
550
-
551
- st.divider()
552
-
553
- # Appeler DeepSeek
554
  if st.session_state.deepseek_caller:
555
- with st.spinner(" Generating feedback with DeepSeek..."):
556
- deepseek_result = st.session_state.deepseek_caller.generate_feedback(context)
557
-
558
- if deepseek_result.get('feedback'):
559
- feedback = deepseek_result['feedback']
560
- tokens_used = deepseek_result['tokens_total']
561
-
562
- st.success(" Feedback Generated!")
563
-
564
- col1, col2, col3 = st.columns(3)
565
- with col1:
566
- st.metric("Tokens Used", tokens_used)
567
- with col2:
568
- st.metric("Generation Time", f"{deepseek_result['generation_time_ms']:.0f} ms")
569
- with col3:
570
- st.metric("Total Time", f"{response_time + deepseek_result['generation_time_ms']:.0f} ms")
571
-
572
- st.markdown("**Generated Feedback:**")
573
  st.write(feedback)
574
-
575
- # Distillation : Ajouter au cache
576
- with st.spinner(" Adding to cache (distillation)..."):
577
- # Encoder le feedback
578
- feedback_embedding = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
579
-
580
- success = st.session_state.cache_manager.add_to_cache(
581
- code=code_input,
582
- feedback=feedback,
583
- metadata=context,
584
- embedding=feedback_embedding
585
- )
586
-
587
- if success:
588
- st.success(" Feedback added to cache for future queries!")
589
-
590
- # Log cache miss (format dataset)
591
- miss_data = {
592
- **context,
593
- "tags": [tag.strip() for tag in error_category.split(',') if tag.strip()] if error_category else [],
594
- "feedback": feedback,
595
- "query_id": cache_result['query_id'],
596
- "tokens_used": tokens_used
597
- }
598
- st.session_state.stats_logger.log_cache_miss(miss_data)
599
-
600
- # Log stats
601
- st.session_state.stats_logger.log_query({
602
- "query_id": cache_result['query_id'],
603
- "status": "miss",
604
- "similarity_score": cache_result.get('closest_distance', 1.0),
605
- "confidence": 1.0, # LLM généré = haute confiance
606
- "response_time_ms": response_time + deepseek_result['generation_time_ms'],
607
- "theme": theme,
608
- "error_category": error_category,
609
- "difficulty": difficulty,
610
- "deepseek_tokens": tokens_used,
611
- "cache_size": st.session_state.collection.count()
612
- })
613
  else:
614
- st.error(f" Error: {deepseek_result.get('error', 'Unknown error')}")
615
  else:
616
- st.error(" DeepSeek API not available. Cannot generate feedback.")
617
-
618
- st.markdown('</div>', unsafe_allow_html=True)
619
 
620
  else:
621
- st.info(" Please configure and load the model + dataset from the sidebar first.")
622
-
623
- st.markdown("""
624
- ### How to use:
625
- 1. **Load Model & Dataset** (or use cached DB)
626
- 2. **Fill in the form** with your code and its context
627
- 3. **Submit** to get feedback
628
- 4. **Check the Stats page** to see cache performance
629
-
630
- ### Cache System:
631
- - **Hit**: Similar code found in database (instant response) Or Relevant feedabck code found in db with code feedback embedder
632
- - **Miss**: No match found, generates new feedback (slower, uses API tokens)
633
- - **Distillation**: New feedbacks are automatically added to the cache
634
- """)
 
1
  """
2
+ Streamlit RAG Viewer avec Cache Intelligent (Static RAG Mode)
3
  """
4
 
5
  import streamlit as st
6
  import torch
7
  import torch.nn.functional as F
8
  from transformers import AutoTokenizer, AutoModel
 
9
  import chromadb
10
  from pathlib import Path
11
  import json
12
  import time
13
  import logging
14
  import sys
15
+ import os
16
+ from huggingface_hub import login, snapshot_download
17
+
18
  # Import des modules custom
19
  from cache_manager import CacheManager
20
  from deepseek_caller import DeepSeekCaller
21
  from stats_logger import StatsLogger
22
  from config import DISTANCE_THRESHOLD
23
  from utils import load_css
 
 
24
 
25
  # ==========================================
26
  # PAGE CONFIG
27
  # ==========================================
28
  st.set_page_config(
29
  page_title="RAG Feedback System",
30
+ page_icon="🧠",
31
  layout="wide",
32
  initial_sidebar_state="expanded"
33
  )
34
 
35
+ # Configuration du Dataset HF contenant la DB Chroma
36
  DATASET_ID = "matis35/chroma-rag-storage"
37
+ REPO_FOLDER = "chroma_db_storage"
 
 
 
38
  LOCAL_CACHE_DIR = Path("./chroma_cache")
39
 
40
  # ==========================================
 
46
# STATE MANAGEMENT
# ==========================================
# Initialise every session-state key once so reruns keep prior values.
if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False
if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False
if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None
if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None
if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger()

# ==========================================
# SETUP & LOGGING
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("FFGen_System")

# Hugging Face authentication: environment variable first, then Streamlit secrets.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # BUGFIX: reading st.secrets raises FileNotFoundError when no
    # secrets.toml exists (plain local execution). The previous revision
    # guarded this access with try/except; restore the guard so the app
    # does not crash at import time outside of Streamlit Cloud.
    try:
        if "HF_TOKEN" in st.secrets:
            hf_token = st.secrets["HF_TOKEN"]
    except FileNotFoundError:
        hf_token = None

if hf_token:
    login(token=hf_token)
72
+
73
+ # ==========================================
74
+ # CORE FUNCTIONS
75
+ # ==========================================
76
+
 
 
 
 
 
77
  @st.cache_resource
78
  def load_full_model(model_path: str):
79
+ """Charge le modèle d'embedding (Hugging Face)"""
80
+ st.info(f"Loading embedding model from: {model_path}...")
 
81
  try:
82
  tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
83
  if tokenizer.pad_token is None:
 
88
  trust_remote_code=True,
89
  device_map="auto"
90
  )
 
91
  model.eval()
92
  return model, tokenizer
93
  except Exception as e:
94
+ st.error(f"Failed to load model: {e}")
 
95
  return None, None
96
 
97
def encode_text(text: str, model, tokenizer):
    """Encode *text* into a unit-norm embedding vector (plain list of floats).

    Pooling is the mean of the model's last hidden states over the sequence
    axis, followed by L2 normalisation.
    """
    # Run on whatever device the model's parameters live on.
    target_device = next(model.parameters()).device

    batch = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    batch = {name: tensor.to(target_device) for name, tensor in batch.items()}

    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        pooled = hidden.mean(dim=1)
        pooled = F.normalize(pooled, p=2, dim=1)

    return pooled[0].cpu().numpy().tolist()
107
 
108
@st.cache_resource
def initialize_chromadb():
    """Fetch the pre-built Chroma DB from the HF Hub (if absent) and open it.

    Static-RAG mode: no re-indexing happens here. Returns a
    ``(client, collection)`` pair for the existing "feedbacks" collection.
    Raises if the download fails or the collection is missing.
    """
    db_path = LOCAL_CACHE_DIR / REPO_FOLDER

    # Download the snapshot only on first run; the folder persists locally.
    if not db_path.exists():
        print(f"📥 Downloading vector DB from {DATASET_ID}...")
        try:
            snapshot_download(
                repo_id=DATASET_ID,
                repo_type="dataset",
                local_dir=LOCAL_CACHE_DIR,
                allow_patterns=[f"{REPO_FOLDER}/*"],
                # NOTE(review): deprecated no-op in recent huggingface_hub -- confirm
                local_dir_use_symlinks=False
            )
            print(" Download complete.")
        except Exception as e:
            st.error(f"Failed to download DB: {e}")
            raise e

    # Connect to the persistent store.
    print(f"🔌 Connecting to ChromaDB at {db_path}")
    client = chromadb.PersistentClient(path=str(db_path))

    # Sanity-check that the expected collection is actually in the snapshot.
    try:
        collection = client.get_collection(name="feedbacks")
        print(f"📊 Collection loaded. Documents: {collection.count()}")
    except Exception as e:
        st.error("Collection 'feedbacks' not found in the downloaded DB.")
        raise e

    return client, collection
145
+
146
  # ==========================================
147
+ # MAIN INTERFACE
148
  # ==========================================
149
 
150
  st.title("FFGEN")
151
  st.markdown("### Submit code and get instant feedback")
152
 
153
+ # --- SIDEBAR ---
 
 
 
154
  with st.sidebar:
155
+ st.header("⚙️ System Configuration")
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ # Model Config
158
+ model_path = st.text_input("Embedding Model", value="matis35/gemmaembedding-fgdor")
159
+
160
  st.divider()
161
+
162
+ # Cache Sensitivity
163
+ st.subheader("Cache Sensitivity")
 
 
164
  if 'custom_threshold' not in st.session_state:
165
  st.session_state.custom_threshold = DISTANCE_THRESHOLD
166
 
167
  custom_threshold = st.slider(
168
+ "Similarity Threshold", 0.1, 1.0,
169
+ value=st.session_state.custom_threshold, step=0.05,
170
+ help="Lower = Stricter matching. Higher = More matches."
 
 
 
171
  )
172
+
173
  if custom_threshold != st.session_state.custom_threshold:
174
  st.session_state.custom_threshold = custom_threshold
 
175
  if st.session_state.get('cache_manager'):
176
  st.session_state.cache_manager.threshold = custom_threshold
 
 
 
177
 
178
  st.divider()
179
 
180
+ # Active Learning Toggle
181
+ enable_learning = st.checkbox(
182
+ "Enable Active Learning",
183
+ value=True,
184
+ help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session."
185
+ )
186
 
187
+ st.divider()
 
 
 
 
188
 
189
+ # Main Action Button
190
+ start_btn = st.button("🚀 Load System", use_container_width=True, type="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
+ if start_btn:
193
+ # 1. Load Model
194
+ with st.spinner("1/2 Loading Neural Model..."):
195
  model, tokenizer = load_full_model(model_path)
196
  if model:
197
  st.session_state.model = model
 
200
  else:
201
  st.stop()
202
 
203
+ # 2. Download & Connect DB
204
+ with st.spinner("2/2 Downloading & Connecting Vector DB..."):
205
  try:
206
+ client, collection = initialize_chromadb() # Appel sans argument !
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  st.session_state.client = client
208
  st.session_state.collection = collection
209
  st.session_state.db_initialized = True
210
+
211
+ # Init Cache Manager
212
  encoder_fn = lambda text: encode_text(text, model, tokenizer)
213
  st.session_state.cache_manager = CacheManager(
214
  collection,
215
  encoder_fn,
216
  threshold=st.session_state.custom_threshold
217
  )
218
+
219
+ # Init DeepSeek
220
  try:
221
  st.session_state.deepseek_caller = DeepSeekCaller()
222
  except:
223
+ st.warning("DeepSeek key not found, generation disabled.")
224
+
225
+ st.success("System Ready!")
226
+ time.sleep(1) # Petit temps pour voir le succès
227
+ st.rerun()
228
+
229
+ except Exception as e:
230
+ st.error(f"Initialization Error: {e}")
231
 
232
# --- MAIN LOGIC ---

if st.session_state.db_initialized and st.session_state.cache_manager:

    # Submission form: code on the left, exercise metadata on the right.
    with st.form("code_submission"):
        col1, col2 = st.columns([2, 1])
        with col1:
            code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }")
        with col2:
            theme = st.text_input("Theme", placeholder="e.g. Arrays")
            difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"])
            error_cat = st.text_input("Error Type (Optional)")

        instructions = st.text_area("Instructions", placeholder="Function should return...")
        submit_btn = st.form_submit_button("Search Feedback", use_container_width=True)

    if submit_btn and code_input:
        t0 = time.time()
        query_context = {
            "code": code_input,
            "theme": theme,
            "difficulty": difficulty,
            "error_category": error_cat,
            "instructions": instructions
        }

        # 1. Semantic lookup in the cached feedback database.
        with st.spinner("🔍 Searching knowledge base..."):
            cache_result = st.session_state.cache_manager.query_cache(code_input, query_context)

        latency_ms = (time.time() - t0) * 1000

        if cache_result['status'] in ['hit', 'perfect_match']:
            # CASE 1: serve the best stored feedback directly.
            st.success(f"Feedback found in {latency_ms:.0f}ms (Confidence: {cache_result['confidence']:.2f})")

            top = cache_result['results'][0]
            st.markdown("### 💡 Retrieved Feedback")
            st.write(top['feedback'])

            with st.expander("See Reference Code"):
                st.code(top['code'], language='c')
                st.caption(f"Distance: {top['distance']:.4f}")
        else:
            # CASE 2: cache miss -> fall back to LLM generation.
            st.warning(f"No similar feedback found (Best distance: {cache_result.get('closest_distance', 1.0):.4f}). Generating new...")

            if not st.session_state.deepseek_caller:
                st.error("DeepSeek not configured.")
            else:
                with st.spinner("🤖 Generating analysis with DeepSeek..."):
                    gen_result = st.session_state.deepseek_caller.generate_feedback(query_context)

                if 'feedback' not in gen_result:
                    st.error("Generation failed.")
                else:
                    feedback = gen_result['feedback']
                    st.markdown("### 🤖 Generated Feedback")
                    st.write(feedback)

                    # Distillation: optionally persist the fresh feedback into
                    # the session's vector cache for future lookups.
                    if enable_learning:
                        with st.spinner("💾 Saving to local session cache..."):
                            emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
                            st.session_state.cache_manager.add_to_cache(
                                code=code_input,
                                feedback=feedback,
                                metadata=query_context,
                                embedding=emb
                            )
                            st.toast("Feedback added to cache!", icon="✅")

else:
    st.info("👈 Please load the system from the sidebar to start.")
 
 
 
 
 
 
 
 
 
 
 
 
 
cache_manager.py CHANGED
@@ -1,9 +1,9 @@
1
  """
2
- Cache Manager - Gère Hit/Miss et distillation
3
  """
4
 
5
  import numpy as np
6
- from typing import Dict, List, Any, Tuple
7
  import uuid
8
  from datetime import datetime
9
  from config import DISTANCE_THRESHOLD, TOP_K_RESULTS, CONFIDENCE_THRESHOLD_WARNING
@@ -14,61 +14,31 @@ class CacheManager:
14
  Args:
15
  chroma_collection: Collection ChromaDB
16
  encoder_fn: Fonction pour encoder du texte en embedding
17
- threshold: Custom similarity threshold (if None, uses config default)
18
  """
19
  self.collection = chroma_collection
20
  self.encoder_fn = encoder_fn
21
  self.threshold = threshold if threshold is not None else DISTANCE_THRESHOLD
22
 
23
  def calculate_confidence(self, distances: List[float]) -> float:
24
- """
25
- Calcule un score de confiance basé sur les distances.
26
- Distance plus faible = confiance plus haute.
27
-
28
- Returns:
29
- float entre 0 et 1
30
- """
31
  if not distances:
32
  return 0.0
33
-
34
- # Distance moyenne
35
  avg_distance = np.mean(distances)
36
-
37
- # Convertir distance en confiance (inverse et normalisation)
38
- # Distance de 0 = confiance 1.0
39
- # Distance de 0.5 = confiance 0.5
40
- # Distance de 1.0 = confiance 0.0
41
- confidence = max(0.0, 1.0 - avg_distance)
42
-
43
- return round(confidence, 3)
44
 
45
  def query_cache(self, code: str, context: Dict[str, Any]) -> Dict[str, Any]:
46
  """
47
- Logique d'exécution (Pipeline) :
48
- 1. CHECK RAPIDE : Match exact de la chaîne de caractères (via Metadata).
49
- -> Si trouvé : Retour immédiat (Stop).
50
-
51
- 2. RETRIEVAL : Recherche des 5 vecteurs les plus proches (Bi-Encoder).
52
-
53
- 3. ANALYSE FINE : Sur ces 5 candidats, on vérifie :
54
- A. Est-ce qu'il y a un "Jumeau Sémantique" ? (Code quasi-identique > 0.95)
55
- -> Si oui : C'est un HIT forcé (Priorité sur le seuil).
56
- B. Est-ce que le meilleur candidat est sous le seuil de distance ?
57
- -> Si oui : C'est un HIT standard.
58
-
59
- 4. DÉCISION : Si ni A ni B -> MISS.
60
  """
61
 
62
- # --- ÉTAPE 1 : CHECK RAPIDE (String Exact Match) ---
63
  try:
64
- # On vérifie si la chaîne de caractères brute existe déjà
65
  if len(code) < 5000:
66
- exact_matches = self.collection.get(
67
- where={"code": code},
68
- limit=1
69
- )
70
  if exact_matches and len(exact_matches['ids']) > 0:
71
- print("Cache: MATCH EXACT (String) trouvé !")
72
  return {
73
  "status": "perfect_match",
74
  "results": [{
@@ -78,22 +48,15 @@ class CacheManager:
78
  "rank": 1,
79
  "metadata": exact_matches['metadatas'][0]
80
  }],
81
- "similarity_scores": [0.0],
82
  "confidence": 1.0,
83
- "needs_deepseek": False,
84
  "needs_warning": False,
85
- "query_id": str(uuid.uuid4()),
86
- "query_embedding": [],
87
- "perfect_code_match": True
88
  }
89
  except Exception as e:
90
- print(f"Warning exact match: {e}")
91
 
92
- # --- ÉTAPE 2 : RETRIEVAL (Recherche Vectorielle) ---
93
- # On a besoin des candidats pour faire les analyses suivantes
94
-
95
  query_embedding = self.encoder_fn(code)
96
-
97
  query_results = self.collection.query(
98
  query_embeddings=[query_embedding],
99
  n_results=TOP_K_RESULTS
@@ -103,129 +66,88 @@ class CacheManager:
103
  documents = query_results['documents'][0] if query_results['documents'] else []
104
  metadatas = query_results['metadatas'][0] if query_results['metadatas'] else []
105
 
106
- # --- ÉTAPE 3 : ANALYSE FINE (Code Similarity Check) ---
107
- # On cherche un "Jumeau Sémantique" parmi les résultats retournés
108
- code_similarity = None
109
  perfect_code_match = False
 
110
 
111
- # On regarde uniquement le meilleur candidat (rank 1) pour la comparaison code-à-code
112
  if metadatas and metadatas[0].get('code'):
113
  ref_code = metadatas[0].get('code')
114
  if ref_code and ref_code != 'N/A':
 
115
  ref_code_embedding = self.encoder_fn(ref_code)
116
- # Produit scalaire
117
  code_similarity = float(np.dot(query_embedding, ref_code_embedding))
118
-
119
- # Si > 0.95, c'est le même code écrit différemment (ex: espaces, commentaires)
120
  if code_similarity > 0.95:
121
  perfect_code_match = True
122
 
123
- # --- ÉTAPE 4 : DÉCISION HIT / MISS ---
124
-
125
- # Condition A : Jumeau Sémantique (Le code est quasi identique)
126
- # Condition B : Proximité Vectorielle Standard (Le sens est proche, sous le seuil)
127
-
128
  is_hit = False
129
  hit_type = "miss"
130
 
131
  if perfect_code_match:
132
  is_hit = True
133
- hit_type = "perfect_match" # Priorité haute
134
  elif distances and distances[0] < self.threshold:
135
  is_hit = True
136
- hit_type = "hit" # Priorité standard
137
 
138
- # --- CONSTRUCTION DE LA RÉPONSE ---
139
-
140
- # Préparation des résultats formatés (utilisé dans les deux cas)
141
  formatted_results = []
142
- for i, (feedback, metadata, distance) in enumerate(zip(documents, metadatas, distances)):
143
  formatted_results.append({
144
  "rank": i + 1,
145
  "feedback": feedback,
146
  "code": metadata.get('code', 'N/A'),
147
- "distance": round(distance, 4),
148
  "metadata": metadata
149
  })
150
 
151
  if is_hit:
152
- # Calcul confiance
153
  confidence = self.calculate_confidence(distances)
154
- if perfect_code_match:
155
- confidence = 1.0 # Boost max car on est sûr du code
156
-
157
  return {
158
  "status": hit_type,
159
  "results": formatted_results,
160
- "similarity_scores": [round(d, 4) for d in distances],
161
- "confidence": confidence,
162
- "needs_deepseek": False,
163
- # Warning uniquement si c'est un hit "mou" (vecteur lointain) ET pas un match de code
164
  "needs_warning": False if perfect_code_match else (confidence < CONFIDENCE_THRESHOLD_WARNING),
165
- "query_embedding": query_embedding,
166
- "query_id": str(uuid.uuid4()),
167
- "code_similarity": round(code_similarity, 4) if code_similarity is not None else None,
168
- "perfect_code_match": perfect_code_match
169
  }
170
-
171
  else:
172
- # MISS
173
  return {
174
  "status": "miss",
175
- "results": formatted_results, # On renvoie quand même les proches pour info
176
- "similarity_scores": [round(d, 4) for d in distances] if distances else [],
177
  "confidence": 0.0,
178
- "needs_deepseek": True,
179
  "needs_warning": False,
180
- "query_embedding": query_embedding,
181
- "query_id": str(uuid.uuid4()),
182
  "closest_distance": round(distances[0], 4) if distances else 1.0
183
  }
 
184
  def add_to_cache(self, code: str, feedback: str, metadata: Dict[str, Any], embedding: List[float]) -> bool:
185
  """
186
- Ajoute une nouvelle entrée au cache (distillation online).
187
-
188
- Args:
189
- code: Code source
190
- feedback: Feedback généré
191
- metadata: Métadonnées complètes (theme, difficulty, etc.)
192
- embedding: Embedding du feedback
193
-
194
- Returns:
195
- bool: True si succès
196
  """
197
  try:
198
- doc_id = f"miss_{uuid.uuid4()}"
199
-
200
- # Préparer metadata pour ChromaDB (seulement le code car limitation)
201
- chroma_metadata = {
202
- "code": code,
203
  "timestamp": datetime.now().isoformat(),
204
- "source": "cache_miss"
 
 
205
  }
206
 
207
  self.collection.add(
208
  embeddings=[embedding],
209
  documents=[feedback],
210
- metadatas=[chroma_metadata],
211
  ids=[doc_id]
212
  )
213
-
214
  return True
215
 
216
  except Exception as e:
217
- print(f"Error adding to cache: {e}")
218
- return False
219
-
220
- def get_cache_stats(self) -> Dict[str, Any]:
221
- """Retourne des stats sur le cache"""
222
- try:
223
- total_docs = self.collection.count()
224
-
225
- return {
226
- "total_documents": total_docs,
227
- "similarity_threshold": SIMILARITY_THRESHOLD,
228
- "top_k": TOP_K_RESULTS
229
- }
230
- except Exception as e:
231
- return {"error": str(e)}
 
1
  """
2
+ Cache Manager - Gère Hit/Miss et distillation locale
3
  """
4
 
5
  import numpy as np
6
+ from typing import Dict, List, Any
7
  import uuid
8
  from datetime import datetime
9
  from config import DISTANCE_THRESHOLD, TOP_K_RESULTS, CONFIDENCE_THRESHOLD_WARNING
 
14
  Args:
15
  chroma_collection: Collection ChromaDB
16
  encoder_fn: Fonction pour encoder du texte en embedding
17
+ threshold: Custom similarity threshold
18
  """
19
  self.collection = chroma_collection
20
  self.encoder_fn = encoder_fn
21
  self.threshold = threshold if threshold is not None else DISTANCE_THRESHOLD
22
 
23
  def calculate_confidence(self, distances: List[float]) -> float:
24
+ """Convertit la distance Chroma (Cosine) en score de confiance [0, 1]."""
 
 
 
 
 
 
25
  if not distances:
26
  return 0.0
27
+ # Avec hnsw:space="cosine", distance = 1 - similarity.
28
+ # Donc Similarity = 1 - distance.
29
  avg_distance = np.mean(distances)
30
+ return max(0.0, 1.0 - avg_distance)
 
 
 
 
 
 
 
31
 
32
  def query_cache(self, code: str, context: Dict[str, Any]) -> Dict[str, Any]:
33
  """
34
+ Recherche dans le cache (Pipeline Hybride: Exact Match -> Vector Search -> Code Comparison)
 
 
 
 
 
 
 
 
 
 
 
 
35
  """
36
 
37
+ # 1. CHECK RAPIDE (String Exact Match)
38
  try:
 
39
  if len(code) < 5000:
40
+ exact_matches = self.collection.get(where={"code": code}, limit=1)
 
 
 
41
  if exact_matches and len(exact_matches['ids']) > 0:
 
42
  return {
43
  "status": "perfect_match",
44
  "results": [{
 
48
  "rank": 1,
49
  "metadata": exact_matches['metadatas'][0]
50
  }],
 
51
  "confidence": 1.0,
 
52
  "needs_warning": False,
53
+ "closest_distance": 0.0
 
 
54
  }
55
  except Exception as e:
56
+ print(f"Warning exact match check: {e}")
57
 
58
+ # 2. RETRIEVAL (Vectorielle)
 
 
59
  query_embedding = self.encoder_fn(code)
 
60
  query_results = self.collection.query(
61
  query_embeddings=[query_embedding],
62
  n_results=TOP_K_RESULTS
 
66
  documents = query_results['documents'][0] if query_results['documents'] else []
67
  metadatas = query_results['metadatas'][0] if query_results['metadatas'] else []
68
 
69
+ # 3. ANALYSE (Similarity Check)
 
 
70
  perfect_code_match = False
71
+ code_similarity = 0.0
72
 
73
+ # Vérification sémantique du code sur le meilleur candidat
74
  if metadatas and metadatas[0].get('code'):
75
  ref_code = metadatas[0].get('code')
76
  if ref_code and ref_code != 'N/A':
77
+ # On encode le code de référence pour comparer avec le code d'entrée
78
  ref_code_embedding = self.encoder_fn(ref_code)
79
+ # Produit scalaire (approximatif si vecteurs normalisés)
80
  code_similarity = float(np.dot(query_embedding, ref_code_embedding))
 
 
81
  if code_similarity > 0.95:
82
  perfect_code_match = True
83
 
84
+ # 4. DÉCISION
 
 
 
 
85
  is_hit = False
86
  hit_type = "miss"
87
 
88
  if perfect_code_match:
89
  is_hit = True
90
+ hit_type = "perfect_match"
91
  elif distances and distances[0] < self.threshold:
92
  is_hit = True
93
+ hit_type = "hit"
94
 
95
+ # Formatage des résultats
 
 
96
  formatted_results = []
97
+ for i, (feedback, metadata, dist) in enumerate(zip(documents, metadatas, distances)):
98
  formatted_results.append({
99
  "rank": i + 1,
100
  "feedback": feedback,
101
  "code": metadata.get('code', 'N/A'),
102
+ "distance": round(dist, 4),
103
  "metadata": metadata
104
  })
105
 
106
  if is_hit:
 
107
  confidence = self.calculate_confidence(distances)
108
+ if perfect_code_match: confidence = 1.0
109
+
 
110
  return {
111
  "status": hit_type,
112
  "results": formatted_results,
113
+ "confidence": round(confidence, 3),
 
 
 
114
  "needs_warning": False if perfect_code_match else (confidence < CONFIDENCE_THRESHOLD_WARNING),
115
+ "closest_distance": round(distances[0], 4)
 
 
 
116
  }
 
117
  else:
 
118
  return {
119
  "status": "miss",
120
+ "results": formatted_results,
 
121
  "confidence": 0.0,
 
122
  "needs_warning": False,
 
 
123
  "closest_distance": round(distances[0], 4) if distances else 1.0
124
  }
125
+
126
  def add_to_cache(self, code: str, feedback: str, metadata: Dict[str, Any], embedding: List[float]) -> bool:
127
  """
128
+ Ajoute au cache local pour la session courante.
 
 
 
 
 
 
 
 
 
129
  """
130
  try:
131
+ doc_id = f"learned_{uuid.uuid4().hex[:8]}"
132
+
133
+ # Nettoyage des métadonnées (Chroma n'aime pas les listes/None)
134
+ safe_metadata = {
135
+ "code": code[:10000], # Limite de taille
136
  "timestamp": datetime.now().isoformat(),
137
+ "source": "active_learning",
138
+ "theme": str(metadata.get("theme", "")),
139
+ "difficulty": str(metadata.get("difficulty", ""))
140
  }
141
 
142
  self.collection.add(
143
  embeddings=[embedding],
144
  documents=[feedback],
145
+ metadatas=[safe_metadata],
146
  ids=[doc_id]
147
  )
148
+ print(f"✅ Learned new feedback: {doc_id}")
149
  return True
150
 
151
  except Exception as e:
152
+ print(f"Error adding to cache: {e}")
153
+ return False