""" Streamlit RAG Viewer with Intelligent Cache (Static RAG Mode) """ import streamlit as st import torch import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel import chromadb from pathlib import Path import json import time import logging import sys import os from huggingface_hub import login, snapshot_download # Import custom modules from cache_manager import CacheManager from deepseek_caller import DeepSeekCaller from stats_logger import StatsLogger from config import DISTANCE_THRESHOLD from utils import load_css # ========================================== # PAGE CONFIG # ========================================== st.set_page_config( page_title="RAG Feedback System", page_icon="🧠", layout="wide", initial_sidebar_state="expanded" ) # Configuration of the HF Dataset containing the Chroma DB DATASET_ID = "matis35/chroma-rag-storage" REPO_FOLDER = "chroma_db_storage" LOCAL_CACHE_DIR = Path("./chroma_cache") # ========================================== # CUSTOM CSS # ========================================== load_css("assets/style.css") # ========================================== # STATE MANAGEMENT # ========================================== if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger() # ========================================== # SETUP & LOGGING # ========================================== logging.basicConfig( level=logging.INFO, format='%(asctime)s | %(levelname)s | %(message)s', datefmt='%H:%M:%S', handlers=[logging.StreamHandler(sys.stdout)] ) logger = logging.getLogger("FFGen_System") # HF Authentication hf_token = os.environ.get("HF_TOKEN") if not hf_token and "HF_TOKEN" in st.secrets: hf_token = st.secrets["HF_TOKEN"] if hf_token: login(token=hf_token) # ========================================== # CORE FUNCTIONS # ========================================== @st.cache_resource def load_full_model(model_path: str): """Load embedding model (Hugging Face)""" st.info(f"Loading embedding model from: {model_path}...") try: tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModel.from_pretrained( model_path, trust_remote_code=True, device_map="auto" ) model.eval() return model, tokenizer except Exception as e: st.error(f"Failed to load model: {e}") return None, None def encode_text(text: str, model, tokenizer): """Generate normalized embedding""" device = next(model.parameters()).device inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = model(**inputs) embeddings = outputs.last_hidden_state.mean(dim=1) embeddings = F.normalize(embeddings, p=2, dim=1) return embeddings[0].cpu().numpy().tolist() @st.cache_resource def initialize_chromadb(): """ Download pre-calculated Chroma DB from Hugging Face. """ final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER # 1. Download if missing print(f"📥 Checking/Downloading vector DB from {DATASET_ID}...") try: snapshot_download( repo_id=DATASET_ID, repo_type="dataset", local_dir=LOCAL_CACHE_DIR, allow_patterns=[f"{REPO_FOLDER}/*"], local_dir_use_symlinks=False ) print("✅ DB ready.") except Exception as e: st.error(f"Failed to download DB: {e}") raise e # 2. Connection print(f"🔌 Connecting to ChromaDB at {final_db_path}") client = chromadb.PersistentClient(path=str(final_db_path)) # 3. Verification try: collection = client.get_collection(name="feedbacks") print(f"📊 Collection loaded. Documents: {collection.count()}") except Exception as e: st.error("Collection 'feedbacks' not found in the downloaded DB.") raise e return client, collection # ========================================== # MAIN INTERFACE # ========================================== st.title("FFGEN") st.markdown("### Submit code and get instant feedback") # --- SIDEBAR --- with st.sidebar: st.header("System Configuration") # Model Config model_path = st.text_input("Embedding Model", value="matis35/feedbacker-2") st.divider() # Cache Sensitivity st.subheader("Cache Sensitivity") if 'custom_threshold' not in st.session_state: st.session_state.custom_threshold = DISTANCE_THRESHOLD custom_threshold = st.slider( "Distance Threshold", 0.1, 1.0, value=st.session_state.custom_threshold, step=0.05, help="If the distance between your code and the feedback is LOWER than this threshold, it is a HIT." ) # Explicit visual indication (English) st.markdown(f"**Rule:** Distance < `{custom_threshold:.2f}` = **HIT**") if custom_threshold != st.session_state.custom_threshold: st.session_state.custom_threshold = custom_threshold if st.session_state.get('cache_manager'): st.session_state.cache_manager.threshold = custom_threshold st.divider() # Active Caching Toggle (Renamed and Default False) enable_caching = st.checkbox( "Enable Active Caching", value=False, help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session." ) st.divider() # Main Action Button start_btn = st.button("Load System", use_container_width=True, type="primary") if start_btn: # 1. Load Model with st.spinner("1/2 Loading Neural Model..."): model, tokenizer = load_full_model(model_path) if model: st.session_state.model = model st.session_state.tokenizer = tokenizer st.session_state.model_loaded = True else: st.stop() # 2. Download & Connect DB with st.spinner("2/2 Downloading & Connecting Vector DB..."): try: client, collection = initialize_chromadb() # Call without argument st.session_state.client = client st.session_state.collection = collection st.session_state.db_initialized = True # Init Cache Manager encoder_fn = lambda text: encode_text(text, model, tokenizer) st.session_state.cache_manager = CacheManager( collection, encoder_fn, threshold=st.session_state.custom_threshold ) # Init DeepSeek try: st.session_state.deepseek_caller = DeepSeekCaller() except: st.warning("DeepSeek key not found, generation disabled.") st.success("System Ready!") time.sleep(1) # Small delay to see success st.rerun() except Exception as e: st.error(f"Initialization Error: {e}") # --- MAIN LOGIC --- if st.session_state.db_initialized and st.session_state.cache_manager: # Submission Form with st.form("code_submission"): col1, col2 = st.columns([2, 1]) with col1: code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }") with col2: theme = st.text_input("Theme", placeholder="e.g. Arrays") difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"]) error_cat = st.text_input("Error Type (Optional)") instructions = st.text_area("Instructions", placeholder="Function should return...") submit_btn = st.form_submit_button("Search Feedback", use_container_width=True) if submit_btn and code_input: start_time = time.time() # Force update threshold st.session_state.cache_manager.threshold = st.session_state.custom_threshold context = { "code": code_input, "theme": theme, "difficulty": difficulty, "error_category": error_cat, "instructions": instructions } # 1. Query Cache with st.spinner("Searching knowledge base..."): cache_result = st.session_state.cache_manager.query_cache(code_input, context) # Calculate timing elapsed_ms = (time.time() - start_time) * 1000 tokens_used = 0 status = cache_result['status'] # --- POP-UP NOTIFICATION (TOAST) --- # Different message based on result quality if status == 'perfect_match': st.toast("**Perfect Match!** Identical code found.", icon="🔥") elif status == 'code_hit': st.toast("**Code Hit!** Code structure is very similar.", icon="💻") elif status in ['feedback_hit', 'hit', 'semantic hit']: st.toast("**Feedback Hit!** Semantic relevance found.", icon="🧠") else: # Miss st.toast("**Cache Miss.** AI Generation in progress...", icon="⏳") # -------------------------------------- # --- MAIN DISPLAY (HIT/MISS) --- if status in ['perfect_match', 'code_hit', 'feedback_hit', 'hit', 'semantic hit']: msg_type = "success" hit_msg = f"Feedback found! ({status.replace('_', ' ').upper()})" else: msg_type = "warning" hit_msg = "No similar feedback found. Generating new..." if msg_type == "success": st.success(f"{hit_msg} in {elapsed_ms:.0f}ms (Confidence: {cache_result['confidence']:.2f})") best = cache_result['results'][0] st.markdown("### Retrieved Feedback") st.write(best['feedback']) with st.expander("See Reference Code"): st.code(best['code'], language='c') st.caption(f"Distance: {best['distance']:.4f}") # --- ANALYSIS SECTION (TOP-K) --- with st.expander(f"Detailed Analysis: Top-{len(cache_result['results'])} Candidates", expanded=False): st.markdown(f"**Current Distance Threshold:** `{st.session_state.custom_threshold}`") st.caption("Distance = User Code → Feedback Embedding (Bi-Encoder)") for res in cache_result['results']: rank = res['rank'] dist = res['distance'] # Color code for distance dist_color = "green" if dist < st.session_state.custom_threshold else "red" st.markdown(f"#### Rank #{rank} : :{dist_color}[Distance {dist:.4f}]") # Side-by-side comparison col_a, col_b = st.columns(2) with col_a: st.markdown("**Stored Feedback:**") st.info(res['feedback']) with col_b: st.markdown("**Reference Code:**") st.code(res['code'][:800] + ("..." if len(res['code']) > 800 else ""), language='c') st.divider() # --- GENERATION IF MISS --- if status == 'miss': if st.session_state.deepseek_caller: with st.spinner("Generating analysis with DeepSeek..."): gen_result = st.session_state.deepseek_caller.generate_feedback(context) elapsed_ms = (time.time() - start_time) * 1000 if 'feedback' in gen_result: feedback = gen_result['feedback'] tokens_used = gen_result.get('tokens_total', 0) st.markdown("### Generated Feedback") st.write(feedback) # LOG ACTIVE CACHING if enable_caching: with st.spinner("Saving to local session cache..."): emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer) st.session_state.cache_manager.add_to_cache( code=code_input, feedback=feedback, metadata=context, embedding=emb ) # Confirmation toast for caching st.toast("Feedback learned and added to cache!", icon="✅") # LOG MISS DETAILS st.session_state.stats_logger.log_cache_miss({ "code": code_input, "feedback": feedback, "theme": theme, "error_category": error_cat, "tokens_used": tokens_used }) else: st.error("Generation failed.") else: st.error("DeepSeek not configured.") # --- FINAL: LOG METRICS FOR DASHBOARD --- st.session_state.stats_logger.log_query({ "status": status, "confidence": cache_result['confidence'], "similarity_score": cache_result.get('closest_distance', 0.0) if status == 'miss' else cache_result['results'][0]['distance'], "response_time_ms": elapsed_ms, "deepseek_tokens": tokens_used, "theme": theme, "difficulty": difficulty, "error_category": error_cat }) else: st.info("Please load the system from the sidebar to start.")