Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| import pandas as pd | |
| import sys | |
| import os | |
| import shutil | |
| from pathlib import Path | |
| import chromadb | |
| from chromadb.config import Settings | |
| import uuid | |
| import tempfile | |
# --- Add scripts to path ---
# Make the directory two levels above this file importable so the project-local
# `scripts` package can be resolved. This must run before the `scripts.*`
# imports below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from scripts.core.ingestion.ingest import GitCrawler  # clones/lists repository files
from scripts.core.ingestion.chunk import RepoChunker  # splits files into chunks
# --- Configuration ---
# Hugging Face model IDs of the two encoders being compared.
BASELINE_MODEL = "microsoft/codebert-base"
FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
# Prefer CUDA when torch can see a GPU, otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Persistent Chroma storage. The relative path is resolved against the current
# working directory (not this file's location).
DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
DB_DIR.mkdir(parents=True, exist_ok=True)


def _load_encoder(model_id):
    """Load the tokenizer/model pair for ``model_id``.

    The model is moved to DEVICE and switched to eval mode. Returns
    ``(tokenizer, model)``.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModel.from_pretrained(model_id)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl


print(f"Loading models on {DEVICE}...")
print("1. Loading baseline model...")
baseline_tokenizer, baseline_model = _load_encoder(BASELINE_MODEL)
print("2. Loading fine-tuned model...")
finetuned_tokenizer, finetuned_model = _load_encoder(FINETUNED_MODEL)
print("Both models loaded!")

# --- ChromaDB Setup ---
# One persistent client holds both collections; both use cosine distance so
# that `1 - distance` is the cosine similarity reported by the search code.
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
baseline_collection = chroma_client.get_or_create_collection(
    name="baseline_rag",
    metadata={"hnsw:space": "cosine"},
)
finetuned_collection = chroma_client.get_or_create_collection(
    name="finetuned_rag",
    metadata={"hnsw:space": "cosine"},
)
# --- Embedding Functions ---
def masked_mean_pool(last_hidden_state, attention_mask):
    """Average token vectors per sequence, ignoring padding positions.

    Args:
        last_hidden_state: float tensor of shape (batch, seq_len, hidden).
        attention_mask: tensor of shape (batch, seq_len); 1 for real tokens,
            0 for padding.

    Returns:
        Tensor of shape (batch, hidden). A row whose mask is all zeros yields
        a zero vector instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Clamp so an all-padding row divides by a tiny number instead of zero.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def _encode(tokenizer, model, text_list):
    """Embed ``text_list`` with ``model``; return L2-normalized vectors or None.

    Pooling is weighted by the attention mask. A plain ``mean(dim=1)`` would
    average the pad-token states that ``padding=True`` adds to the shorter
    texts in a batch. The same text would then embed differently as a stored
    document (batched with longer neighbours) than as a query (embedded
    alone). Collections indexed with the old pooling should be reset and
    re-ingested.
    """
    if not text_list:
        return None
    inputs = tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        out = model(**inputs)
    emb = masked_mean_pool(out.last_hidden_state, inputs["attention_mask"])
    return F.normalize(emb, p=2, dim=1)


def compute_baseline_embeddings(text_list):
    """Return normalized baseline-model embeddings (batch, hidden), or None if empty."""
    return _encode(baseline_tokenizer, baseline_model, text_list)


def compute_finetuned_embeddings(text_list):
    """Return normalized fine-tuned-model embeddings (batch, hidden), or None if empty."""
    return _encode(finetuned_tokenizer, finetuned_model, text_list)
# --- Reset Functions ---
def _recreate_collection(name):
    """Delete the Chroma collection ``name`` and return a new, empty one.

    The new collection uses cosine distance. If the delete raises, nothing is
    recreated and the exception propagates.
    """
    chroma_client.delete_collection(name)
    return chroma_client.get_or_create_collection(name=name, metadata={"hnsw:space": "cosine"})


def reset_baseline():
    """Wipe the baseline collection and rebind the module-level handle to it.

    Returns a status string for the UI.
    """
    global baseline_collection
    baseline_collection = _recreate_collection("baseline_rag")
    return "Baseline database reset."


def reset_finetuned():
    """Wipe the fine-tuned collection and rebind the module-level handle to it.

    Returns a status string for the UI.
    """
    global finetuned_collection
    finetuned_collection = _recreate_collection("finetuned_rag")
    return "Fine-tuned database reset."
# --- Database Inspector Functions ---
def summarize_collection_files(collection, page_size=1000):
    """Return one ``[file_name, chunk_count, source_url]`` row per indexed file.

    The whole collection is read in pages of ``page_size`` records. A single
    ``get(limit=1000)`` would silently under-count, and even drop files, once a
    collection holds more than 1000 chunks. Rows are sorted by chunk count,
    descending; ties keep first-seen order.

    Grouping is by the stored ``file_name`` (a basename), so same-named files
    from different directories or repositories share one row. That row carries
    the URL of the first such chunk seen.

    Returns a single placeholder row when the collection is empty, or an
    ``"Error: ..."`` row if reading fails.
    """
    total = collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]
    try:
        file_stats = {}
        offset = 0
        while offset < total:
            page = collection.get(limit=page_size, offset=offset, include=["metadatas"])
            metadatas = page["metadatas"] or []
            if not metadatas:
                # Collection shrank while we were paging; stop rather than loop forever.
                break
            for meta in metadatas:
                meta = meta or {}
                stats = file_stats.setdefault(
                    meta.get("file_name", "unknown"),
                    {"count": 0, "url": meta.get("url", "unknown")},
                )
                stats["count"] += 1
            offset += len(metadatas)
        rows = [[fname, stats["count"], stats["url"]] for fname, stats in file_stats.items()]
        return sorted(rows, key=lambda row: row[1], reverse=True)
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]


def list_baseline_files():
    """Per-file chunk counts for the baseline collection (UI table rows)."""
    return summarize_collection_files(baseline_collection)


def list_finetuned_files():
    """Per-file chunk counts for the fine-tuned collection (UI table rows)."""
    return summarize_collection_files(finetuned_collection)
# --- Search Functions ---
def search_collection(collection, embed_fn, query, top_k=5, snippet_chars=300):
    """Return the ``top_k`` nearest chunks to ``query`` as UI table rows.

    Args:
        collection: Chroma collection configured for cosine distance.
        embed_fn: callable mapping a list of strings to a tensor of embeddings
            (or None).
        query: free-text or code query. Blank queries return no rows.
        top_k: requested number of results. It is coerced to an int of at
            least 1 (slider values may arrive as floats) and capped at the
            collection size.
        snippet_chars: snippets longer than this are truncated and suffixed
            with "...".

    Returns:
        List of ``[file_name, "score", snippet]`` rows, where score is
        ``1 - cosine distance`` formatted to 4 decimals.
    """
    if not query or not query.strip():
        return []
    total = collection.count()  # read once so both uses agree
    if total == 0:
        return []
    query_emb = embed_fn([query])
    if query_emb is None:
        return []
    query_vec = query_emb.cpu().numpy().tolist()[0]
    results = collection.query(
        query_embeddings=[query_vec],
        n_results=min(max(1, int(top_k)), total),
        include=["metadatas", "documents", "distances"],
    )
    if not results["ids"]:
        return []
    rows = []
    for meta, code, dist in zip(
        results["metadatas"][0], results["documents"][0], results["distances"][0]
    ):
        meta = meta or {}
        code = code or ""
        # Only mark a snippet as truncated when it actually was.
        snippet = code if len(code) <= snippet_chars else code[:snippet_chars] + "..."
        rows.append([meta.get("file_name", "unknown"), f"{1 - dist:.4f}", snippet])
    return rows


def search_baseline(query, top_k=5):
    """Search the baseline collection using the baseline encoder."""
    return search_collection(baseline_collection, compute_baseline_embeddings, query, top_k)


def search_finetuned(query, top_k=5):
    """Search the fine-tuned collection using the fine-tuned encoder."""
    return search_collection(finetuned_collection, compute_finetuned_embeddings, query, top_k)


def search_comparison(query, top_k=5):
    """Run ``query`` against both collections; return (baseline_rows, finetuned_rows)."""
    return search_baseline(query, top_k), search_finetuned(query, top_k)
# --- Ingestion Functions ---
def index_chunks(collection, embed_fn, chunks, source_url, label, batch_size=64):
    """Embed ``chunks`` in batches, add them to ``collection``, and yield progress.

    Each chunk must expose ``.code`` (the text embedded and stored as the
    document) and ``.file_path`` (only its basename is stored, as
    ``file_name`` metadata, alongside ``url=source_url``). Every record gets a
    fresh random UUID, so indexing the same source twice stores duplicates.

    Yields ``"<label>: <done>/<total>"`` after each batch.
    """
    total = len(chunks)
    for start in range(0, total, batch_size):
        batch = chunks[start:start + batch_size]
        texts = [c.code for c in batch]
        embeddings = embed_fn(texts)
        if embeddings is not None:
            collection.add(
                ids=[str(uuid.uuid4()) for _ in batch],
                embeddings=embeddings.cpu().numpy().tolist(),
                metadatas=[{"file_name": Path(c.file_path).name, "url": source_url} for c in batch],
                documents=texts,
            )
        yield f"{label}: {min(start + batch_size, total)}/{total}"


def ingest_from_url(repo_url):
    """Clone ``repo_url``, chunk its files, and index the chunks in both collections.

    Any previous checkout under ``data/raw_ingest`` (relative to the working
    directory) is deleted first. Yields human-readable status strings.
    Per-file chunking errors are printed and skipped. Any other error is
    reported as ``"Error: ..."``.
    """
    import stat
    import traceback

    repo_url = (repo_url or "").strip()
    if not repo_url.startswith(("http://", "https://")):
        yield "Invalid URL"
        return
    data_dir = Path(os.path.abspath("data/raw_ingest"))

    def remove_readonly(func, path, _):
        # Some checkout files can be read-only (e.g. on Windows); make the
        # path writable and retry the failed removal.
        os.chmod(path, stat.S_IWRITE)
        func(path)

    try:
        if data_dir.exists():
            shutil.rmtree(data_dir, onerror=remove_readonly)
        yield f"Cloning {repo_url}..."
        crawler = GitCrawler(cache_dir=data_dir)
        repo_path = crawler.clone_repository(repo_url)
        if not repo_path:
            yield "Failed to clone repository."
            return
        yield "Listing files..."
        files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
        # NOTE(review): list_files apparently may return a tuple whose first
        # item holds objects with `.path`; confirm against GitCrawler.
        if isinstance(files, tuple):
            files = [f.path for f in files[0]]
        total_files = len(files)
        yield f"Found {total_files} files. Chunking..."
        chunker = RepoChunker()
        all_chunks = []
        for i, file_path in enumerate(files, start=1):
            file_path = Path(file_path)
            yield f"Chunking: {i}/{total_files} ({file_path.name})"
            try:
                meta = {"file_name": file_path.name, "url": repo_url}
                all_chunks.extend(chunker.chunk_file(file_path, repo_metadata=meta))
            except Exception as e:
                print(f"Skipping {file_path}: {e}")
        if not all_chunks:
            yield "No valid chunks found."
            return
        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
        yield from index_chunks(baseline_collection, compute_baseline_embeddings, all_chunks, repo_url, "Baseline")
        yield "Embedding (FINE-TUNED)..."
        yield from index_chunks(finetuned_collection, compute_finetuned_embeddings, all_chunks, repo_url, "Fine-tuned")
        yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
    except Exception as e:
        traceback.print_exc()
        yield f"Error: {str(e)}"


def ingest_from_files(files):
    """Chunk uploaded files and index the chunks in both collections.

    ``files`` may contain objects exposing the on-disk path as ``.name`` (as
    Gradio's File component provides) or plain path strings. Yields status
    strings; a file that fails to chunk is reported and skipped.
    """
    import traceback

    if not files:
        yield "No files uploaded."
        return
    try:
        yield f"Processing {len(files)} file(s)..."
        chunker = RepoChunker()
        all_chunks = []
        for i, file in enumerate(files, start=1):
            file_path = Path(getattr(file, "name", file))
            yield f"Chunking file {i}/{len(files)}: {file_path.name}"
            try:
                meta = {"file_name": file_path.name, "url": "uploaded"}
                all_chunks.extend(chunker.chunk_file(file_path, repo_metadata=meta))
            except Exception as e:
                yield f"Error chunking {file_path.name}: {str(e)}"
                traceback.print_exc()
        if not all_chunks:
            yield "No valid chunks found."
            return
        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
        yield from index_chunks(baseline_collection, compute_baseline_embeddings, all_chunks, "uploaded", "Baseline")
        yield "Embedding (FINE-TUNED)..."
        yield from index_chunks(finetuned_collection, compute_finetuned_embeddings, all_chunks, "uploaded", "Fine-tuned")
        yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
    except Exception as e:
        traceback.print_exc()
        yield f"Error: {str(e)}"
# --- Analysis & Evaluation Functions ---
def sampled_mean_similarity(X, sample_size=100):
    """Mean pairwise dot product between distinct rows of a random subset of ``X``.

    Rows are drawn without replacement (``torch.randperm``). Sampling with
    replacement (as ``torch.randint`` did) can pick the same row twice; the
    resulting off-diagonal pair has similarity 1.0 and biases the mean
    upward. For L2-normalized rows the dot product is the cosine similarity.

    Returns NaN when ``X`` has fewer than two rows (no pairs exist).
    """
    n = X.shape[0]
    if n < 2:
        return float("nan")
    idx = torch.randperm(n)[: min(sample_size, n)]
    sample = X[idx]
    sim = sample @ sample.T
    off_diagonal = ~torch.eye(len(sample), dtype=torch.bool)
    return sim[off_diagonal].mean().item()


def _analyze_collection(collection, model_label, plot_title):
    """Compute diversity metrics and a 2-D PCA scatter plot for ``collection``.

    Up to 2000 stored embeddings are analyzed. Returns ``(metrics_text,
    PIL.Image)``, or ``(message, None)`` when there is too little data or an
    error occurs.
    """
    count = collection.count()
    if count < 5:
        return "Not enough data (Need at least 5 chunks).", None
    try:
        import io

        import matplotlib.pyplot as plt
        import numpy as np
        from PIL import Image

        data = collection.get(limit=min(count, 2000), include=["embeddings", "metadatas"])
        # np.asarray handles both list-of-lists and list-of-ndarray results.
        X = torch.as_tensor(np.asarray(data["embeddings"], dtype=np.float32))
        X_centered = X - X.mean(dim=0)
        _, _, V = torch.pca_lowrank(X_centered, q=2)
        projected = (X_centered @ V[:, :2]).numpy()

        avg_sim = sampled_mean_similarity(X)
        diversity_score = 1.0 - avg_sim
        metrics = (
            f"{model_label}\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )
        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [(m or {}).get("file_name", "unknown") for m in data["metadatas"]],
        })

        # Work on this figure explicitly instead of pyplot's global "current
        # figure", so concurrent handlers cannot save or close each other's plot.
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            for topic in plot_df["topic"].unique():
                rows = plot_df[plot_df["topic"] == topic]
                ax.scatter(rows["x"], rows["y"], label=topic, alpha=0.6, s=50)
            ax.set_xlabel("PC1")
            ax.set_ylabel("PC2")
            ax.set_title(plot_title, fontsize=14, pad=20)
            ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=8)
            ax.grid(True, alpha=0.3)
            fig.tight_layout()
            buf = io.BytesIO()
            fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
        finally:
            plt.close(fig)
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode now rather than lazily on first access
        return metrics, img
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None


def analyze_embeddings_baseline():
    """Diversity metrics and PCA plot for the baseline collection."""
    return _analyze_collection(baseline_collection, "BASELINE MODEL", "Baseline Semantic Space (PCA)")


def analyze_embeddings_finetuned():
    """Diversity metrics and PCA plot for the fine-tuned collection."""
    return _analyze_collection(finetuned_collection, "FINE-TUNED MODEL", "Fine-tuned Semantic Space (PCA)")
def evaluate_collection(collection, embed_fn, sample_limit, report_label, progress_label):
    """Self-retrieval benchmark for one collection, streamed as status strings.

    For each of up to ``sample_limit`` chunks (sampled from the first 2000
    stored), the first 3 lines of the chunk are used as the query. The
    function records whether, and at what rank, the chunk's own id appears in
    the top 10 results. It then reports Recall@1, Recall@5 and MRR (averaged
    over the top 10).

    Caveats: chunks of 3 lines or fewer are queried with their full text, and
    duplicate documents with different ids compete for rank 1.

    This is a generator, so every outcome (including "not enough data") must
    be yielded; a ``return value`` would never reach the UI.
    """
    count = collection.count()
    if count < 10:
        yield "Not enough data for evaluation (Need at least 10 chunks)."
        return
    try:
        import random

        data = collection.get(limit=min(count, 2000), include=["documents"])
        ids = data["ids"]
        docs = data["documents"]
        # Slider values may arrive as floats; random.sample needs an int.
        sample_size = min(int(sample_limit), len(ids))
        if sample_size < 1:
            yield "Sample size must be at least 1."
            return
        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0.0
        yield f"{report_label}: Evaluating {sample_size} chunks..."
        for done, idx in enumerate(random.sample(range(len(ids)), sample_size), start=1):
            target_id = ids[idx]
            query = "\n".join(docs[idx].split("\n")[:3])
            query_vec = embed_fn([query]).cpu().numpy().tolist()[0]
            found_ids = collection.query(query_embeddings=[query_vec], n_results=10)["ids"][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1:
                    hits_at_1 += 1
                if rank <= 5:
                    hits_at_5 += 1
            if done % 10 == 0:
                yield f"{progress_label}: {done}/{sample_size}..."
        report = (
            f"{report_label} EVALUATION ({sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {hits_at_1 / sample_size:.4f}\n"
            f"Recall@5: {hits_at_5 / sample_size:.4f}\n"
            f"MRR: {mrr_sum / sample_size:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"


def evaluate_retrieval_baseline(sample_limit):
    """Stream the self-retrieval evaluation for the baseline collection."""
    yield from evaluate_collection(
        baseline_collection, compute_baseline_embeddings, sample_limit, "BASELINE", "Baseline"
    )


def evaluate_retrieval_finetuned(sample_limit):
    """Stream the self-retrieval evaluation for the fine-tuned collection."""
    yield from evaluate_collection(
        finetuned_collection, compute_finetuned_embeddings, sample_limit, "FINE-TUNED", "Fine-tuned"
    )
# --- UI ---
# Overrides applied on top of the Soft theme: light page background and
# white blocks with a 1px border and semi-bold titles.
_theme_overrides = dict(
    body_background_fill="*neutral_50",
    block_background_fill="white",
    block_border_width="1px",
    block_title_text_weight="600",
)
theme = gr.themes.Soft(
    primary_hue="slate",
    neutral_hue="slate",
    spacing_size="sm",
    radius_size="md",
).set(**_theme_overrides)

# Page CSS: centred title, capped container width, and the class used by the
# side-by-side result headers in the comparison tab.
css = "\n".join([
    "",
    "h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }",
    ".gradio-container { max-width: 1400px !important; margin: auto; }",
    ".comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }",
    "",
])
# Gradio app layout: four tabs (ingest, text comparison search, code
# similarity search, monitoring/evaluation). Event handlers are the
# module-level functions defined above.
# NOTE(review): this block's indentation was reconstructed from a flattened
# paste; confirm the nesting of the reset row and inspector against the source.
with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
    gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
    gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and **MRL-enhanced fine-tuned** model")
    with gr.Tabs():
        # TAB 1: INGEST
        with gr.Tab("1. Ingest Code"):
            with gr.Tabs():
                with gr.Tab("GitHub Repository"):
                    repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
                    ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
                    url_status = gr.Textbox(label="Status")
                    # ingest_from_url is a generator, so status updates stream into the textbox.
                    ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)
                with gr.Tab("Upload Python Files"):
                    file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
                    ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
                    upload_status = gr.Textbox(label="Status")
                    ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)
            # Reset buttons drop and recreate one collection each.
            with gr.Row():
                reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
                reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
            reset_status = gr.Textbox(label="Reset Status")
            reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
            reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)
            gr.Markdown("---")
            gr.Markdown("### Database Inspector")
            gr.Markdown("View indexed files in each collection")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Collection")
                    inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
                    # NOTE(review): "Chunks" is declared numeric but the placeholder row uses "-".
                    baseline_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Collection")
                    inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
                    finetuned_files_df = gr.Dataframe(
                        headers=["File Name", "Chunks", "Source URL"],
                        datatype=["str", "number", "str"],
                        interactive=False,
                        value=[["No data yet", "-", "-"]]
                    )
                    inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)
        # TAB 2: COMPARISON SEARCH
        with gr.Tab("2. Comparison Search (Note: Semantic search is sensitive to query phrasing)"):
            gr.Markdown("### Side-by-Side Retrieval Comparison")
            search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
            compare_btn = gr.Button("Compare Models", variant="primary")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
                    baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
                with gr.Column():
                    gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
                    finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
            # Only the query is wired in; search_comparison uses its default top_k.
            compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])
        # TAB 3: CODE SIMILARITY SEARCH
        with gr.Tab("3. Code Similarity Search"):
            gr.Markdown("### Find Similar Code Snippets")
            gr.Markdown("Paste a code snippet to find similar code in the database")
            with gr.Row():
                with gr.Column():
                    code_input = gr.Code(label="Paste Code Snippet", language="python", lines=10)
                    similarity_btn = gr.Button("Find Similar Code", variant="primary")
                with gr.Column():
                    gr.Markdown("#### Search Settings")
                    top_k_slider = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of Results")
                    model_choice = gr.Radio(["Baseline", "Fine-tuned", "Both"], value="Both", label="Model to Use")
            gr.Markdown("### Results")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Results")
                    baseline_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Results")
                    finetuned_code_results = gr.Dataframe(
                        headers=["File", "Similarity", "Code Snippet"],
                        datatype=["str", "str", "str"],
                        interactive=False,
                        wrap=True,
                        value=[["No search yet", "-", "-"]]
                    )
            # Note: the `model_choice` parameter below shadows the Radio component of
            # the same name; inside the function it is the selected string value.
            def search_similar_code(code_snippet, top_k, model_choice):
                """Use the pasted snippet as a query against the selected collection(s).

                Returns ``(baseline_rows, finetuned_rows)``. A blank snippet yields a
                prompt row on both sides, an empty result yields "No results found",
                and a side that was not selected yields "Not searched".
                """
                if not code_snippet or len(code_snippet.strip()) == 0:
                    empty = [["Enter code to search", "-", "-"]]
                    return empty, empty
                baseline_res = []
                finetuned_res = []
                if model_choice in ["Baseline", "Both"]:
                    baseline_res = search_baseline(code_snippet, top_k)
                    if not baseline_res:
                        baseline_res = [["No results found", "-", "-"]]
                if model_choice in ["Fine-tuned", "Both"]:
                    finetuned_res = search_finetuned(code_snippet, top_k)
                    if not finetuned_res:
                        finetuned_res = [["No results found", "-", "-"]]
                # Replace the untouched side's empty list with an explicit placeholder.
                if model_choice == "Baseline":
                    finetuned_res = [["Not searched", "-", "-"]]
                elif model_choice == "Fine-tuned":
                    baseline_res = [["Not searched", "-", "-"]]
                return baseline_res, finetuned_res
            similarity_btn.click(
                search_similar_code,
                inputs=[code_input, top_k_slider, model_choice],
                outputs=[baseline_code_results, finetuned_code_results]
            )
        # TAB 4: DEPLOYMENT MONITORING
        with gr.Tab("4. Deployment Monitoring"):
            gr.Markdown("### Embedding Quality Analysis")
            gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Model")
                    analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
                    baseline_metrics = gr.Textbox(label="Baseline Metrics")
                    baseline_plot = gr.Image()
                    analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Model")
                    analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
                    finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
                    finetuned_plot = gr.Image()
                    analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])
            gr.Markdown("---")
            gr.Markdown("### Retrieval Performance Evaluation")
            gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")
            # One slider drives both evaluations so they use the same sample size.
            eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Baseline Evaluation")
                    eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
                    baseline_eval_output = gr.Textbox(label="Baseline Results")
                    eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)
                with gr.Column():
                    gr.Markdown("#### Fine-tuned Evaluation")
                    eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
                    finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
                    eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
if __name__ == "__main__":
    # Enable the request queue (needed for the streaming generator handlers),
    # then serve on all interfaces, port 7860, without a public share link.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0", server_port=7860, share=False)