# FFGEN-Demo / app.py
# Author: Matis Codjia
# Last commit: 52d9664 — "Fix: default encoder"
"""
Streamlit RAG Viewer with Intelligent Cache (Static RAG Mode)
"""
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import chromadb
from pathlib import Path
import json
import time
import logging
import sys
import os
from huggingface_hub import login, snapshot_download
# Import custom modules
from cache_manager import CacheManager
from deepseek_caller import DeepSeekCaller
from stats_logger import StatsLogger
from config import DISTANCE_THRESHOLD
from utils import load_css
# ==========================================
# PAGE CONFIG
# ==========================================
# NOTE: st.set_page_config must be the very first Streamlit call in the script.
st.set_page_config(
    page_title="RAG Feedback System",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Configuration of the HF Dataset containing the Chroma DB
DATASET_ID = "matis35/chroma-rag-storage"  # HF dataset repo hosting the pre-built Chroma DB
REPO_FOLDER = "chroma_db_storage"          # folder inside that dataset repo
LOCAL_CACHE_DIR = Path("./chroma_cache")   # local download target for snapshot_download
# ==========================================
# CUSTOM CSS
# ==========================================
load_css("assets/style.css")
# ==========================================
# STATE MANAGEMENT
# ==========================================
# Seed every session-state key exactly once per session. Factories (not
# values) are stored so StatsLogger() is only constructed when its key is
# actually missing — same lazy behavior as the original per-key checks.
_SESSION_DEFAULTS = {
    'model_loaded': lambda: False,
    'db_initialized': lambda: False,
    'cache_manager': lambda: None,
    'deepseek_caller': lambda: None,
    'stats_logger': StatsLogger,
}
for _key, _factory in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
# ==========================================
# SETUP & LOGGING
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("FFGen_System")

# HF Authentication: prefer the environment variable, then Streamlit secrets.
# Merely touching st.secrets raises when no secrets.toml is present (typical
# for local runs without a .streamlit/ folder), so the lookup is guarded
# instead of crashing the whole app at import time.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    try:
        hf_token = st.secrets.get("HF_TOKEN")
    except Exception:
        hf_token = None
if hf_token:
    login(token=hf_token)
# ==========================================
# CORE FUNCTIONS
# ==========================================
@st.cache_resource
def load_full_model(model_path: str):
    """Load the Hugging Face embedding model and its tokenizer.

    Returns a (model, tokenizer) pair, or (None, None) when loading fails;
    the failure is reported in the UI rather than raised.
    """
    st.info(f"Loading embedding model from: {model_path}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # Some tokenizers ship without a pad token; fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModel.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True,
        )
        return model.eval(), tokenizer
    except Exception as exc:
        st.error(f"Failed to load model: {exc}")
        return None, None
def encode_text(text: str, model, tokenizer):
    """Generate an L2-normalized embedding for *text*.

    Mean-pools the model's last hidden state. The pooling is weighted by the
    attention mask so that padding positions do not dilute the average; for a
    single unpadded sequence (the current call pattern) this is numerically
    identical to the previous plain mean, but it stays correct if this helper
    is ever fed padded/batched input.

    Returns:
        list[float]: the embedding of the first (only) sequence.
    """
    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        hidden = outputs.last_hidden_state
        mask = inputs.get("attention_mask")
        if mask is None:
            # No mask available — fall back to the plain mean.
            embeddings = hidden.mean(dim=1)
        else:
            # Masked mean: zero out pad positions, divide by real token count.
            mask = mask.unsqueeze(-1).to(hidden.dtype)
            embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings[0].cpu().numpy().tolist()
@st.cache_resource
def initialize_chromadb():
    """
    Download the pre-calculated Chroma DB from Hugging Face and connect to it.

    Returns:
        (client, collection): the persistent Chroma client and the
        'feedbacks' collection. Errors are shown in the UI and re-raised.
    """
    final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER
    # 1. Download if missing (snapshot_download is a no-op when already cached).
    # Use the module logger instead of print, consistent with the rest of the file.
    logger.info("📥 Checking/Downloading vector DB from %s...", DATASET_ID)
    try:
        snapshot_download(
            repo_id=DATASET_ID,
            repo_type="dataset",
            local_dir=LOCAL_CACHE_DIR,
            allow_patterns=[f"{REPO_FOLDER}/*"],
            local_dir_use_symlinks=False  # real files, not symlinks into the HF cache
        )
        logger.info("✅ DB ready.")
    except Exception as e:
        st.error(f"Failed to download DB: {e}")
        raise  # bare raise preserves the original traceback
    # 2. Connection
    logger.info("🔌 Connecting to ChromaDB at %s", final_db_path)
    client = chromadb.PersistentClient(path=str(final_db_path))
    # 3. Verification
    try:
        collection = client.get_collection(name="feedbacks")
        logger.info("📊 Collection loaded. Documents: %s", collection.count())
    except Exception:
        st.error("Collection 'feedbacks' not found in the downloaded DB.")
        raise
    return client, collection
# ==========================================
# MAIN INTERFACE
# ==========================================
st.title("FFGEN")
st.markdown("### Submit code and get instant feedback")
# --- SIDEBAR ---
with st.sidebar:
    st.header("System Configuration")
    # Model Config: repo id passed to load_full_model when "Load System" is clicked
    model_path = st.text_input("Embedding Model", value="matis35/feedbacker-2")
    st.divider()
    # Cache Sensitivity
    st.subheader("Cache Sensitivity")
    # Seed the threshold once from config so the slider survives reruns.
    if 'custom_threshold' not in st.session_state:
        st.session_state.custom_threshold = DISTANCE_THRESHOLD
    custom_threshold = st.slider(
        "Distance Threshold", 0.1, 1.0,
        value=st.session_state.custom_threshold, step=0.05,
        help="If the distance between your code and the feedback is LOWER than this threshold, it is a HIT."
    )
    # Explicit visual indication (English)
    st.markdown(f"**Rule:** Distance < `{custom_threshold:.2f}` = **HIT**")
    # Persist slider changes and push them into the live cache manager (if loaded).
    if custom_threshold != st.session_state.custom_threshold:
        st.session_state.custom_threshold = custom_threshold
        if st.session_state.get('cache_manager'):
            st.session_state.cache_manager.threshold = custom_threshold
    st.divider()
    # Active Caching Toggle (Renamed and Default False)
    enable_caching = st.checkbox(
        "Enable Active Caching",
        value=False,
        help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session."
    )
    st.divider()
    # Main Action Button
    start_btn = st.button("Load System", use_container_width=True, type="primary")
# "Load System" handler: load the embedding model, then the vector DB,
# then wire up the cache manager and (optionally) the DeepSeek caller.
if start_btn:
    # 1. Load Model
    with st.spinner("1/2 Loading Neural Model..."):
        model, tokenizer = load_full_model(model_path)
    # Explicit None check: truthiness of a model object is not what we mean here.
    if model is not None:
        st.session_state.model = model
        st.session_state.tokenizer = tokenizer
        st.session_state.model_loaded = True
    else:
        st.stop()  # load_full_model already surfaced the error in the UI
    # 2. Download & Connect DB
    with st.spinner("2/2 Downloading & Connecting Vector DB..."):
        try:
            client, collection = initialize_chromadb()  # Call without argument
            st.session_state.client = client
            st.session_state.collection = collection
            st.session_state.db_initialized = True
            # Init Cache Manager
            encoder_fn = lambda text: encode_text(text, model, tokenizer)
            st.session_state.cache_manager = CacheManager(
                collection,
                encoder_fn,
                threshold=st.session_state.custom_threshold
            )
            # Init DeepSeek — optional: a missing API key only disables generation.
            try:
                st.session_state.deepseek_caller = DeepSeekCaller()
            except Exception:  # was a bare except: would also swallow SystemExit/KeyboardInterrupt
                st.warning("DeepSeek key not found, generation disabled.")
            st.success("System Ready!")
            time.sleep(1)  # Small delay to see success
            st.rerun()
        except Exception as e:
            st.error(f"Initialization Error: {e}")
# --- MAIN LOGIC ---
# Runs only once the model + DB are loaded. Flow: form → cache query →
# toast + display on hit → DeepSeek generation on miss → metrics logging.
# NOTE(review): the cache_result schema ('status', 'confidence', 'results',
# 'closest_distance') is defined by CacheManager.query_cache, which is not
# visible in this file — confirm against cache_manager.py.
if st.session_state.db_initialized and st.session_state.cache_manager:
    # Submission Form
    with st.form("code_submission"):
        col1, col2 = st.columns([2, 1])
        with col1:
            code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }")
        with col2:
            theme = st.text_input("Theme", placeholder="e.g. Arrays")
            difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"])
            error_cat = st.text_input("Error Type (Optional)")
            instructions = st.text_area("Instructions", placeholder="Function should return...")
        submit_btn = st.form_submit_button("Search Feedback", use_container_width=True)
    if submit_btn and code_input:
        start_time = time.time()
        # Force update threshold (slider may have moved since the manager was built)
        st.session_state.cache_manager.threshold = st.session_state.custom_threshold
        context = {
            "code": code_input, "theme": theme,
            "difficulty": difficulty, "error_category": error_cat,
            "instructions": instructions
        }
        # 1. Query Cache
        with st.spinner("Searching knowledge base..."):
            cache_result = st.session_state.cache_manager.query_cache(code_input, context)
        # Calculate timing
        elapsed_ms = (time.time() - start_time) * 1000
        tokens_used = 0  # updated below only when DeepSeek actually generates
        status = cache_result['status']
        # --- POP-UP NOTIFICATION (TOAST) ---
        # Different message based on result quality
        if status == 'perfect_match':
            st.toast("**Perfect Match!** Identical code found.", icon="🔥")
        elif status == 'code_hit':
            st.toast("**Code Hit!** Code structure is very similar.", icon="💻")
        elif status in ['feedback_hit', 'hit', 'semantic hit']:
            st.toast("**Feedback Hit!** Semantic relevance found.", icon="🧠")
        else:  # Miss
            st.toast("**Cache Miss.** AI Generation in progress...", icon="⏳")
        # --------------------------------------
        # --- MAIN DISPLAY (HIT/MISS) ---
        if status in ['perfect_match', 'code_hit', 'feedback_hit', 'hit', 'semantic hit']:
            msg_type = "success"
            hit_msg = f"Feedback found! ({status.replace('_', ' ').upper()})"
        else:
            msg_type = "warning"
            hit_msg = "No similar feedback found. Generating new..."
        if msg_type == "success":
            st.success(f"{hit_msg} in {elapsed_ms:.0f}ms (Confidence: {cache_result['confidence']:.2f})")
            # Best candidate is rank 0 of the (presumably distance-sorted) results.
            best = cache_result['results'][0]
            st.markdown("### Retrieved Feedback")
            st.write(best['feedback'])
            with st.expander("See Reference Code"):
                st.code(best['code'], language='c')
                st.caption(f"Distance: {best['distance']:.4f}")
            # --- ANALYSIS SECTION (TOP-K) ---
            with st.expander(f"Detailed Analysis: Top-{len(cache_result['results'])} Candidates", expanded=False):
                st.markdown(f"**Current Distance Threshold:** `{st.session_state.custom_threshold}`")
                st.caption("Distance = User Code → Feedback Embedding (Bi-Encoder)")
                for res in cache_result['results']:
                    rank = res['rank']
                    dist = res['distance']
                    # Color code for distance (below threshold = hit = green)
                    dist_color = "green" if dist < st.session_state.custom_threshold else "red"
                    st.markdown(f"#### Rank #{rank} : :{dist_color}[Distance {dist:.4f}]")
                    # Side-by-side comparison
                    col_a, col_b = st.columns(2)
                    with col_a:
                        st.markdown("**Stored Feedback:**")
                        st.info(res['feedback'])
                    with col_b:
                        st.markdown("**Reference Code:**")
                        # Truncate long reference code for display only.
                        st.code(res['code'][:800] + ("..." if len(res['code']) > 800 else ""), language='c')
                    st.divider()
        # --- GENERATION IF MISS ---
        if status == 'miss':
            if st.session_state.deepseek_caller:
                with st.spinner("Generating analysis with DeepSeek..."):
                    gen_result = st.session_state.deepseek_caller.generate_feedback(context)
                # Re-measure: generation dominates the response time on a miss.
                elapsed_ms = (time.time() - start_time) * 1000
                if 'feedback' in gen_result:
                    feedback = gen_result['feedback']
                    tokens_used = gen_result.get('tokens_total', 0)
                    st.markdown("### Generated Feedback")
                    st.write(feedback)
                    # LOG ACTIVE CACHING
                    if enable_caching:
                        with st.spinner("Saving to local session cache..."):
                            # NOTE(review): the feedback TEXT (not the code) is embedded
                            # here — confirm this matches CacheManager's query embedding.
                            emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
                            st.session_state.cache_manager.add_to_cache(
                                code=code_input,
                                feedback=feedback,
                                metadata=context,
                                embedding=emb
                            )
                        # Confirmation toast for caching
                        st.toast("Feedback learned and added to cache!", icon="✅")
                    # LOG MISS DETAILS
                    st.session_state.stats_logger.log_cache_miss({
                        "code": code_input,
                        "feedback": feedback,
                        "theme": theme,
                        "error_category": error_cat,
                        "tokens_used": tokens_used
                    })
                else:
                    st.error("Generation failed.")
            else:
                st.error("DeepSeek not configured.")
        # --- FINAL: LOG METRICS FOR DASHBOARD ---
        # On a miss there is no results[0]; fall back to 'closest_distance'.
        st.session_state.stats_logger.log_query({
            "status": status,
            "confidence": cache_result['confidence'],
            "similarity_score": cache_result.get('closest_distance', 0.0) if status == 'miss' else cache_result['results'][0]['distance'],
            "response_time_ms": elapsed_ms,
            "deepseek_tokens": tokens_used,
            "theme": theme,
            "difficulty": difficulty,
            "error_category": error_cat
        })
else:
    st.info("Please load the system from the sidebar to start.")