Spaces:
Sleeping
Sleeping
| """ | |
| Streamlit RAG Viewer with Intelligent Cache (Static RAG Mode) | |
| """ | |
| import streamlit as st | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| import chromadb | |
| from pathlib import Path | |
| import json | |
| import time | |
| import logging | |
| import sys | |
| import os | |
| from huggingface_hub import login, snapshot_download | |
| # Import custom modules | |
| from cache_manager import CacheManager | |
| from deepseek_caller import DeepSeekCaller | |
| from stats_logger import StatsLogger | |
| from config import DISTANCE_THRESHOLD | |
| from utils import load_css | |
# ==========================================
# PAGE CONFIG
# ==========================================
# Must be the first Streamlit call in the script.
st.set_page_config(
    page_title="RAG Feedback System",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Configuration of the HF Dataset containing the Chroma DB
DATASET_ID = "matis35/chroma-rag-storage"
REPO_FOLDER = "chroma_db_storage"
LOCAL_CACHE_DIR = Path("./chroma_cache")

# ==========================================
# CUSTOM CSS
# ==========================================
load_css("assets/style.css")

# ==========================================
# STATE MANAGEMENT
# ==========================================
# Initialize each session-state key exactly once per browser session.
_SESSION_DEFAULTS = {
    'model_loaded': False,
    'db_initialized': False,
    'cache_manager': None,
    'deepseek_caller': None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# StatsLogger is constructed lazily so it is created once, not on every rerun.
if 'stats_logger' not in st.session_state:
    st.session_state.stats_logger = StatsLogger()

# ==========================================
# SETUP & LOGGING
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("FFGen_System")

# HF Authentication: prefer the environment variable, fall back to Streamlit secrets.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # FIX: accessing st.secrets raises when no secrets.toml exists at all,
    # which crashed startup on deployments without a secrets file. Guard it
    # so a missing secrets store simply means "no token".
    try:
        hf_token = st.secrets.get("HF_TOKEN")
    except Exception:
        hf_token = None
if hf_token:
    login(token=hf_token)
# ==========================================
# CORE FUNCTIONS
# ==========================================
def load_full_model(model_path: str):
    """Load the embedding model and its tokenizer from the Hugging Face hub.

    Args:
        model_path: Hub repo id or local path of the embedding model.

    Returns:
        (model, tokenizer) on success, or (None, None) if loading failed
        (the error is surfaced in the Streamlit UI instead of raised).
    """
    st.info(f"Loading embedding model from: {model_path}...")
    try:
        tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Some checkpoints ship without a pad token; reuse EOS so padded
        # batches still tokenize cleanly.
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        embedder = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto",
        )
        embedder.eval()
    except Exception as exc:
        st.error(f"Failed to load model: {exc}")
        return None, None
    return embedder, tok
def encode_text(text: str, model, tokenizer):
    """Embed *text* and return an L2-normalized vector as a list of floats.

    The embedding is the mean of the model's last hidden state over the
    token axis, normalized to unit length (so cosine similarity reduces
    to a dot product).
    """
    device = next(model.parameters()).device
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Mean-pool over tokens, then L2-normalize the single-row batch.
    pooled = F.normalize(hidden.mean(dim=1), p=2, dim=1)
    return pooled[0].cpu().numpy().tolist()
def initialize_chromadb():
    """Download the pre-calculated Chroma DB from Hugging Face and connect.

    Returns:
        (client, collection): a chromadb PersistentClient and its
        'feedbacks' collection.

    Raises:
        Exception: re-raised (after surfacing in the UI) when the download
        fails or the expected collection is missing.
    """
    final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER

    # 1. Download the snapshot if missing (no-op when already cached).
    # FIX: use the configured module logger instead of bare print(), and
    # drop the deprecated `local_dir_use_symlinks` kwarg (ignored by
    # modern huggingface_hub, which copies into local_dir by default).
    logger.info("📥 Checking/Downloading vector DB from %s...", DATASET_ID)
    try:
        snapshot_download(
            repo_id=DATASET_ID,
            repo_type="dataset",
            local_dir=LOCAL_CACHE_DIR,
            allow_patterns=[f"{REPO_FOLDER}/*"],
        )
        logger.info("✅ DB ready.")
    except Exception as e:
        st.error(f"Failed to download DB: {e}")
        raise  # FIX: bare raise preserves the original traceback

    # 2. Connection
    logger.info("🔌 Connecting to ChromaDB at %s", final_db_path)
    client = chromadb.PersistentClient(path=str(final_db_path))

    # 3. Verification: the DB must contain the pre-built collection.
    try:
        collection = client.get_collection(name="feedbacks")
        logger.info("📊 Collection loaded. Documents: %s", collection.count())
    except Exception:
        st.error("Collection 'feedbacks' not found in the downloaded DB.")
        raise

    return client, collection
# ==========================================
# MAIN INTERFACE
# ==========================================
st.title("FFGEN")
st.markdown("### Submit code and get instant feedback")

# --- SIDEBAR ---
with st.sidebar:
    st.header("System Configuration")

    # Model Config
    model_path = st.text_input("Embedding Model", value="matis35/feedbacker-2")
    st.divider()

    # Cache Sensitivity
    st.subheader("Cache Sensitivity")
    if 'custom_threshold' not in st.session_state:
        st.session_state.custom_threshold = DISTANCE_THRESHOLD
    custom_threshold = st.slider(
        "Distance Threshold", 0.1, 1.0,
        value=st.session_state.custom_threshold, step=0.05,
        help="If the distance between your code and the feedback is LOWER than this threshold, it is a HIT."
    )
    # Explicit visual indication (English)
    st.markdown(f"**Rule:** Distance < `{custom_threshold:.2f}` = **HIT**")
    if custom_threshold != st.session_state.custom_threshold:
        st.session_state.custom_threshold = custom_threshold
        # Keep an already-loaded cache manager in sync with the slider.
        if st.session_state.get('cache_manager'):
            st.session_state.cache_manager.threshold = custom_threshold
    st.divider()

    # Active Caching Toggle (default off: generated feedbacks stay session-local)
    enable_caching = st.checkbox(
        "Enable Active Caching",
        value=False,
        help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session."
    )
    st.divider()

    # Main Action Button
    start_btn = st.button("Load System", use_container_width=True, type="primary")

    if start_btn:
        # 1. Load Model
        with st.spinner("1/2 Loading Neural Model..."):
            model, tokenizer = load_full_model(model_path)
            if model:
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.session_state.model_loaded = True
            else:
                st.stop()

        # 2. Download & Connect DB
        with st.spinner("2/2 Downloading & Connecting Vector DB..."):
            try:
                client, collection = initialize_chromadb()  # Call without argument
                st.session_state.client = client
                st.session_state.collection = collection
                st.session_state.db_initialized = True

                # Init Cache Manager (closure captures the freshly loaded model)
                encoder_fn = lambda text: encode_text(text, model, tokenizer)
                st.session_state.cache_manager = CacheManager(
                    collection,
                    encoder_fn,
                    threshold=st.session_state.custom_threshold
                )

                # Init DeepSeek — optional; without a key, generation is disabled.
                # FIX: narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit (e.g. st.stop) are not silently swallowed.
                try:
                    st.session_state.deepseek_caller = DeepSeekCaller()
                except Exception:
                    st.warning("DeepSeek key not found, generation disabled.")

                st.success("System Ready!")
                time.sleep(1)  # Small delay to see success
                st.rerun()
            except Exception as e:
                st.error(f"Initialization Error: {e}")
# --- MAIN LOGIC ---
if st.session_state.db_initialized and st.session_state.cache_manager:
    # FIX: the set of statuses counted as a cache hit was duplicated in two
    # diverging literal lists (toast branch vs. display branch); define it
    # once so both stay consistent.
    HIT_STATUSES = ('perfect_match', 'code_hit', 'feedback_hit', 'hit', 'semantic hit')

    # Submission Form
    with st.form("code_submission"):
        col1, col2 = st.columns([2, 1])
        with col1:
            code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }")
        with col2:
            theme = st.text_input("Theme", placeholder="e.g. Arrays")
            difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"])
            error_cat = st.text_input("Error Type (Optional)")
            instructions = st.text_area("Instructions", placeholder="Function should return...")
        submit_btn = st.form_submit_button("Search Feedback", use_container_width=True)

    if submit_btn and code_input:
        start_time = time.time()
        # Force update threshold so the query uses the current slider value.
        st.session_state.cache_manager.threshold = st.session_state.custom_threshold

        context = {
            "code": code_input, "theme": theme,
            "difficulty": difficulty, "error_category": error_cat,
            "instructions": instructions
        }

        # 1. Query Cache
        with st.spinner("Searching knowledge base..."):
            cache_result = st.session_state.cache_manager.query_cache(code_input, context)

        # Calculate timing
        elapsed_ms = (time.time() - start_time) * 1000
        tokens_used = 0
        status = cache_result['status']

        # --- POP-UP NOTIFICATION (TOAST) ---
        # Different message based on result quality
        if status == 'perfect_match':
            st.toast("**Perfect Match!** Identical code found.", icon="🔥")
        elif status == 'code_hit':
            st.toast("**Code Hit!** Code structure is very similar.", icon="💻")
        elif status in HIT_STATUSES:
            st.toast("**Feedback Hit!** Semantic relevance found.", icon="🧠")
        else:  # Miss
            st.toast("**Cache Miss.** AI Generation in progress...", icon="⏳")

        # --- MAIN DISPLAY (HIT/MISS) ---
        if status in HIT_STATUSES:
            msg_type = "success"
            hit_msg = f"Feedback found! ({status.replace('_', ' ').upper()})"
        else:
            msg_type = "warning"
            hit_msg = "No similar feedback found. Generating new..."

        if msg_type == "success":
            st.success(f"{hit_msg} in {elapsed_ms:.0f}ms (Confidence: {cache_result['confidence']:.2f})")
            best = cache_result['results'][0]
            st.markdown("### Retrieved Feedback")
            st.write(best['feedback'])
            with st.expander("See Reference Code"):
                st.code(best['code'], language='c')
                st.caption(f"Distance: {best['distance']:.4f}")

            # --- ANALYSIS SECTION (TOP-K) ---
            with st.expander(f"Detailed Analysis: Top-{len(cache_result['results'])} Candidates", expanded=False):
                st.markdown(f"**Current Distance Threshold:** `{st.session_state.custom_threshold}`")
                st.caption("Distance = User Code → Feedback Embedding (Bi-Encoder)")
                for res in cache_result['results']:
                    rank = res['rank']
                    dist = res['distance']
                    # Color code for distance
                    dist_color = "green" if dist < st.session_state.custom_threshold else "red"
                    st.markdown(f"#### Rank #{rank} : :{dist_color}[Distance {dist:.4f}]")
                    # Side-by-side comparison
                    col_a, col_b = st.columns(2)
                    with col_a:
                        st.markdown("**Stored Feedback:**")
                        st.info(res['feedback'])
                    with col_b:
                        st.markdown("**Reference Code:**")
                        st.code(res['code'][:800] + ("..." if len(res['code']) > 800 else ""), language='c')
                    st.divider()

        # --- GENERATION IF MISS ---
        if status == 'miss':
            if st.session_state.deepseek_caller:
                with st.spinner("Generating analysis with DeepSeek..."):
                    gen_result = st.session_state.deepseek_caller.generate_feedback(context)
                    # Re-measure: generation dominates the response time on a miss.
                    elapsed_ms = (time.time() - start_time) * 1000
                    if 'feedback' in gen_result:
                        feedback = gen_result['feedback']
                        tokens_used = gen_result.get('tokens_total', 0)
                        st.markdown("### Generated Feedback")
                        st.write(feedback)

                        # ACTIVE CACHING: learn this feedback for the session.
                        if enable_caching:
                            with st.spinner("Saving to local session cache..."):
                                emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
                                st.session_state.cache_manager.add_to_cache(
                                    code=code_input,
                                    feedback=feedback,
                                    metadata=context,
                                    embedding=emb
                                )
                            # Confirmation toast for caching
                            st.toast("Feedback learned and added to cache!", icon="✅")

                        # LOG MISS DETAILS
                        st.session_state.stats_logger.log_cache_miss({
                            "code": code_input,
                            "feedback": feedback,
                            "theme": theme,
                            "error_category": error_cat,
                            "tokens_used": tokens_used
                        })
                    else:
                        st.error("Generation failed.")
            else:
                st.error("DeepSeek not configured.")

        # --- FINAL: LOG METRICS FOR DASHBOARD ---
        st.session_state.stats_logger.log_query({
            "status": status,
            "confidence": cache_result['confidence'],
            "similarity_score": cache_result.get('closest_distance', 0.0) if status == 'miss' else cache_result['results'][0]['distance'],
            "response_time_ms": elapsed_ms,
            "deepseek_tokens": tokens_used,
            "theme": theme,
            "difficulty": difficulty,
            "error_category": error_cat
        })
else:
    st.info("Please load the system from the sidebar to start.")