"""
Streamlit RAG Viewer with Intelligent Cache (Static RAG Mode)
"""

import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import chromadb
from pathlib import Path
import time
import logging
import sys
import os
from huggingface_hub import login, snapshot_download

# Import custom modules
from cache_manager import CacheManager
from deepseek_caller import DeepSeekCaller
from stats_logger import StatsLogger
from config import DISTANCE_THRESHOLD
from utils import load_css
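# (cache_manager.py, deepseek_caller.py, stats_logger.py, config.py and
#  utils.py are project-local modules expected to sit next to this file.)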

# ==========================================
# PAGE CONFIG
# ==========================================
st.set_page_config(
    page_title="RAG Feedback System",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded"
)
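# Note: st.set_page_config must be the first Streamlit command executed in the
# script, which is why it runs before any CSS, state, or logging setup.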

# Configuration of the HF Dataset containing the Chroma DB
DATASET_ID = "matis35/chroma-rag-storage"
REPO_FOLDER = "chroma_db_storage"
LOCAL_CACHE_DIR = Path("./chroma_cache")

# ==========================================
# CUSTOM CSS
# ==========================================
load_css("assets/style.css")

# ==========================================
# STATE MANAGEMENT
# ==========================================
if 'model_loaded' not in st.session_state: st.session_state.model_loaded = False
if 'db_initialized' not in st.session_state: st.session_state.db_initialized = False
if 'cache_manager' not in st.session_state: st.session_state.cache_manager = None
if 'deepseek_caller' not in st.session_state: st.session_state.deepseek_caller = None
if 'stats_logger' not in st.session_state: st.session_state.stats_logger = StatsLogger()
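# st.session_state keeps these objects alive across Streamlit's top-to-bottom
# reruns, so the model, DB client and managers are not rebuilt on every
# widget interaction.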

# ==========================================
# SETUP & LOGGING
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("FFGen_System")

# HF Authentication
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    # st.secrets raises FileNotFoundError when no secrets.toml exists,
    # so guard the lookup instead of testing membership directly.
    try:
        hf_token = st.secrets["HF_TOKEN"]
    except (KeyError, FileNotFoundError):
        hf_token = None

if hf_token:
    login(token=hf_token)

# ==========================================
# CORE FUNCTIONS
# ==========================================

@st.cache_resource
def load_full_model(model_path: str):
    """Load embedding model (Hugging Face)"""
    st.info(f"Loading embedding model from: {model_path}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto"
        )
        model.eval()
        return model, tokenizer
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None
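# Note: device_map="auto" requires the `accelerate` package to be installed;
# it places the model on GPU when one is available and falls back to CPU.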

def encode_text(text: str, model, tokenizer):
    """Generate normalized embedding"""
    device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        embeddings = F.normalize(embeddings, p=2, dim=1)
    return embeddings[0].cpu().numpy().tolist()
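# Illustrative usage (assumes a loaded model/tokenizer and an open Chroma
# collection): the plain Python list returned above can be passed straight
# to Chroma as a query embedding, e.g.
#
#   emb = encode_text("int main(void) { return 0; }", model, tokenizer)
#   hits = collection.query(query_embeddings=[emb], n_results=5)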

@st.cache_resource
def initialize_chromadb():
    """
    Download pre-calculated Chroma DB from Hugging Face.
    """
    final_db_path = LOCAL_CACHE_DIR / REPO_FOLDER
    
    # 1. Download if missing
    print(f"πŸ“₯ Checking/Downloading vector DB from {DATASET_ID}...")
    try:
        snapshot_download(
            repo_id=DATASET_ID,
            repo_type="dataset",
            local_dir=LOCAL_CACHE_DIR,
            allow_patterns=[f"{REPO_FOLDER}/*"],
            local_dir_use_symlinks=False
        )
        print("βœ… DB ready.")
    except Exception as e:
        st.error(f"Failed to download DB: {e}")
        raise e
            
    # 2. Connection
    print(f"πŸ”Œ Connecting to ChromaDB at {final_db_path}")
    client = chromadb.PersistentClient(path=str(final_db_path))

    # 3. Verification
    try:
        collection = client.get_collection(name="feedbacks")
        print(f"πŸ“Š Collection loaded. Documents: {collection.count()}")
    except Exception as e:
        st.error("Collection 'feedbacks' not found in the downloaded DB.")
        raise e

    return client, collection
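# After the first run, the snapshot persists at LOCAL_CACHE_DIR / REPO_FOLDER
# ("./chroma_cache/chroma_db_storage"); delete that folder to force a fresh
# download of the vector DB.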

# ==========================================
# MAIN INTERFACE
# ==========================================

st.title("FFGEN")
st.markdown("### Submit code and get instant feedback")

# --- SIDEBAR ---
with st.sidebar:
    st.header("System Configuration")

    # Model Config
    model_path = st.text_input("Embedding Model", value="matis35/feedbacker-2")
    
    st.divider()
    
    # Cache Sensitivity
    st.subheader("Cache Sensitivity")
    if 'custom_threshold' not in st.session_state:
        st.session_state.custom_threshold = DISTANCE_THRESHOLD

    custom_threshold = st.slider(
        "Distance Threshold", 0.1, 1.0, 
        value=st.session_state.custom_threshold, step=0.05,
        help="If the distance between your code and the feedback is LOWER than this threshold, it is a HIT."
    )
    
    # Explicit visual indication of the hit rule
    st.markdown(f"**Rule:** Distance < `{custom_threshold:.2f}` = **HIT**")
    
    if custom_threshold != st.session_state.custom_threshold:
        st.session_state.custom_threshold = custom_threshold
        if st.session_state.get('cache_manager'):
            st.session_state.cache_manager.threshold = custom_threshold
    
    st.divider()

    # Active Caching Toggle (Renamed and Default False)
    enable_caching = st.checkbox(
        "Enable Active Caching", 
        value=False,
        help="If checked, new feedbacks generated by DeepSeek will be added to the local cache for this session."
    )

    st.divider()

    # Main Action Button
    start_btn = st.button("Load System", use_container_width=True, type="primary")

    if start_btn:
        # 1. Load Model
        with st.spinner("1/2 Loading Neural Model..."):
            model, tokenizer = load_full_model(model_path)
            if model:
                st.session_state.model = model
                st.session_state.tokenizer = tokenizer
                st.session_state.model_loaded = True
            else:
                st.stop()

        # 2. Download & Connect DB
        with st.spinner("2/2 Downloading & Connecting Vector DB..."):
            try:
                client, collection = initialize_chromadb()  # No argument needed: paths come from the module-level constants
                st.session_state.client = client
                st.session_state.collection = collection
                st.session_state.db_initialized = True
                
                # Init Cache Manager
                encoder_fn = lambda text: encode_text(text, model, tokenizer)
                st.session_state.cache_manager = CacheManager(
                    collection,
                    encoder_fn,
                    threshold=st.session_state.custom_threshold
                )
                
                # Init DeepSeek
                try:
                    st.session_state.deepseek_caller = DeepSeekCaller()
                except Exception:
                    st.warning("DeepSeek key not found, generation disabled.")
                
                st.success("System Ready!")
                time.sleep(1) # Small delay to see success
                st.rerun()
                
            except Exception as e:
                st.error(f"Initialization Error: {e}")

# --- MAIN LOGIC ---

if st.session_state.db_initialized and st.session_state.cache_manager:
    
    # Submission Form
    with st.form("code_submission"):
        col1, col2 = st.columns([2, 1])
        with col1:
            code_input = st.text_area("C Code", height=300, placeholder="int main() { ... }")
        with col2:
            theme = st.text_input("Theme", placeholder="e.g. Arrays")
            difficulty = st.selectbox("Difficulty", ["beginner", "intermediate", "advanced"])
            error_cat = st.text_input("Error Type (Optional)")
        
        instructions = st.text_area("Instructions", placeholder="Function should return...")
        submit_btn = st.form_submit_button("Search Feedback", use_container_width=True)

    if submit_btn and code_input:
        start_time = time.time()
        
        # Force update threshold
        st.session_state.cache_manager.threshold = st.session_state.custom_threshold

        context = {
            "code": code_input, "theme": theme, 
            "difficulty": difficulty, "error_category": error_cat,
            "instructions": instructions
        }

        # 1. Query Cache
        with st.spinner("Searching knowledge base..."):
            cache_result = st.session_state.cache_manager.query_cache(code_input, context)
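        # Shape of cache_result as consumed below (defined in cache_manager;
        # inferred here from usage): {'status': str, 'confidence': float,
        # 'closest_distance': float on a miss, 'results': [{'rank', 'distance',
        # 'feedback', 'code'}, ...]}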
        
        # Calculate timing
        elapsed_ms = (time.time() - start_time) * 1000
        tokens_used = 0
        status = cache_result['status']

        # --- POP-UP NOTIFICATION (TOAST) ---
        # Different message based on result quality
        if status == 'perfect_match':
            st.toast("**Perfect Match!** Identical code found.", icon="πŸ”₯")
        elif status == 'code_hit':
            st.toast("**Code Hit!** Code structure is very similar.", icon="πŸ’»")
        elif status in ['feedback_hit', 'hit', 'semantic hit']:
            st.toast("**Feedback Hit!** Semantic relevance found.", icon="🧠")
        else: # Miss
            st.toast("**Cache Miss.** AI Generation in progress...", icon="⏳")
        # --------------------------------------

        # --- MAIN DISPLAY (HIT/MISS) ---
        if status in ['perfect_match', 'code_hit', 'feedback_hit', 'hit', 'semantic hit']:
            msg_type = "success"
            hit_msg = f"Feedback found! ({status.replace('_', ' ').upper()})"
        else:
            msg_type = "warning"
            hit_msg = "No similar feedback found. Generating new..."

        if msg_type == "success":
            st.success(f"{hit_msg} in {elapsed_ms:.0f}ms (Confidence: {cache_result['confidence']:.2f})")
            
            best = cache_result['results'][0]
            st.markdown("### Retrieved Feedback")
            st.write(best['feedback'])
            
            with st.expander("See Reference Code"):
                st.code(best['code'], language='c')
                st.caption(f"Distance: {best['distance']:.4f}")

        # --- ANALYSIS SECTION (TOP-K) ---
        with st.expander(f"Detailed Analysis: Top-{len(cache_result['results'])} Candidates", expanded=False):
            st.markdown(f"**Current Distance Threshold:** `{st.session_state.custom_threshold}`")
            st.caption("Distance = User Code β†’ Feedback Embedding (Bi-Encoder)")
            
            for res in cache_result['results']:
                rank = res['rank']
                dist = res['distance']
                
                # Color code for distance
                dist_color = "green" if dist < st.session_state.custom_threshold else "red"
                
                st.markdown(f"#### Rank #{rank} : :{dist_color}[Distance {dist:.4f}]")
                
                # Side-by-side comparison
                col_a, col_b = st.columns(2)
                with col_a:
                    st.markdown("**Stored Feedback:**")
                    st.info(res['feedback'])
                with col_b:
                    st.markdown("**Reference Code:**")
                    st.code(res['code'][:800] + ("..." if len(res['code']) > 800 else ""), language='c')
                
                st.divider()

        # --- GENERATION IF MISS ---
        if status == 'miss':
            if st.session_state.deepseek_caller:
                with st.spinner("Generating analysis with DeepSeek..."):
                    gen_result = st.session_state.deepseek_caller.generate_feedback(context)
                    elapsed_ms = (time.time() - start_time) * 1000
                
                if 'feedback' in gen_result:
                    feedback = gen_result['feedback']
                    tokens_used = gen_result.get('tokens_total', 0)
                    
                    st.markdown("### Generated Feedback")
                    st.write(feedback)
                    
                    # LOG ACTIVE CACHING
                    if enable_caching:
                        with st.spinner("Saving to local session cache..."):
                            emb = encode_text(feedback, st.session_state.model, st.session_state.tokenizer)
                            st.session_state.cache_manager.add_to_cache(
                                code=code_input,
                                feedback=feedback,
                                metadata=context,
                                embedding=emb
                            )
                        # Confirmation toast for caching
                        st.toast("Feedback learned and added to cache!", icon="βœ…")

                    # LOG MISS DETAILS
                    st.session_state.stats_logger.log_cache_miss({
                        "code": code_input,
                        "feedback": feedback,
                        "theme": theme,
                        "error_category": error_cat,
                        "tokens_used": tokens_used
                    })

                else:
                    st.error("Generation failed.")
            else:
                st.error("DeepSeek not configured.")

        # --- FINAL: LOG METRICS FOR DASHBOARD ---
        st.session_state.stats_logger.log_query({
            "status": status,
            "confidence": cache_result['confidence'],
            "similarity_score": cache_result.get('closest_distance', 0.0) if status == 'miss' else cache_result['results'][0]['distance'],
            "response_time_ms": elapsed_ms,
            "deepseek_tokens": tokens_used,
            "theme": theme,
            "difficulty": difficulty,
            "error_category": error_cat
        })

else:
    st.info("Please load the system from the sidebar to start.")