| import gradio as gr |
| from qdrant_client import QdrantClient |
| from qdrant_client.models import Distance, VectorParams, PointStruct |
| from sentence_transformers import SentenceTransformer |
| from datasets import load_dataset |
| import pandas as pd |
| import os |
| import tqdm |
| import uuid |
|
|
| |
# Hugging Face dataset containing pre-computed code-chunk embeddings.
DATASET_ID = "stevenbucaille/semantic-transformers"
# Name of the Qdrant collection the vectors are stored in.
COLLECTION_NAME = "transformers_code"

# On-disk storage path for the embedded (local, serverless) Qdrant instance.
QDRANT_PATH = "./qdrant_data"

# Embedding model used to encode search queries.
# NOTE(review): assumes this is the same model that produced the dataset's
# precomputed "embedding" column — verify, otherwise scores are meaningless.
DEFAULT_MODEL = "Snowflake/snowflake-arctic-embed-m"

print("Initializing Qdrant Client...")
# path= runs Qdrant in embedded mode persisted to local disk; no server needed.
client = QdrantClient(path=QDRANT_PATH)

print("Loading Embedding Model for queries...")
# trust_remote_code is required because the model repo may ship custom code.
model = SentenceTransformer(DEFAULT_MODEL, trust_remote_code=True)
|
|
|
|
def initialize_index(progress=gr.Progress()):
    """
    Ensure the Qdrant collection exists and is populated.

    Checks whether the collection already holds vectors; if so, returns
    immediately. Otherwise downloads the dataset, validates that it carries
    pre-computed embeddings, and upserts all rows in batches.

    Args:
        progress: Gradio progress tracker used to surface status in the UI.

    Returns:
        A human-readable status string (success or error description).
    """
    # Fast path: collection already exists and is non-empty.
    collections = client.get_collections().collections
    exists = any(c.name == COLLECTION_NAME for c in collections)
    if exists:
        count = client.count(COLLECTION_NAME).count
        if count > 0:
            return f"Index already exists with {count} vectors. Ready."

    progress(0.1, desc=f"Pulling dataset {DATASET_ID}...")
    try:
        ds = load_dataset(DATASET_ID, split="train")
        df = ds.to_pandas()
    except Exception as e:
        return f"Error loading dataset: {e}"

    if "embedding" not in df.columns:
        return "Error: Dataset does not contain 'embedding' column."

    df = df.dropna(subset=["embedding"])
    total_vectors = len(df)
    # Guard: an all-NaN embedding column would make df.iloc[0] raise IndexError.
    if total_vectors == 0:
        return "Error: Dataset contains no rows with embeddings."

    # Infer vector dimensionality from the first row.
    vec_size = len(df.iloc[0]["embedding"])

    # recreate_collection drops any stale/partial collection before indexing.
    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vec_size, distance=Distance.COSINE),
    )

    BATCH_SIZE = 500
    points = []

    progress(0.2, desc="Indexing vectors...")

    # Use a positional counter (pos) for progress reporting: after dropna()
    # the dataframe index can be sparse/non-sequential, so the previous
    # idx/total_vectors ratio could exceed 1.0. The original index is still
    # used as the stable point ID.
    for pos, (idx, row) in enumerate(
        tqdm.tqdm(df.iterrows(), total=total_vectors), start=1
    ):
        # Everything except the raw vector becomes the point's payload.
        payload = row.drop("embedding").to_dict()

        points.append(
            PointStruct(id=idx, vector=row["embedding"], payload=payload)
        )

        if len(points) >= BATCH_SIZE:
            client.upsert(collection_name=COLLECTION_NAME, points=points)
            points = []

        if pos % 5000 == 0:
            progress(
                0.2 + 0.8 * (pos / total_vectors),
                desc=f"Indexed {pos}/{total_vectors}...",
            )

    # Flush the final partial batch.
    if points:
        client.upsert(collection_name=COLLECTION_NAME, points=points)

    return f"Successfully indexed {total_vectors} chunks."
|
|
|
|
def search_code(query, limit=5):
    """
    Embed a query string and run a vector search against the collection.

    Args:
        query: Natural-language text or a code snippet to search for.
        limit: Maximum number of hits to return. Coerced to int because
            Gradio sliders may deliver a float.

    Returns:
        Either an error string for blank input, or a list of
        (score, file_path, name, "start-end", code) tuples.
    """
    if not query.strip():
        return "Please enter a query."

    # SentenceTransformer.encode returns a numpy array; convert to a plain
    # list, the canonical wire format Qdrant expects.
    query_vector = model.encode(query).tolist()

    hits = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_vector,
        limit=int(limit),
    )

    results = []
    for hit in hits:
        payload = hit.payload
        results.append(
            (
                hit.score,
                payload.get("file_path", "Unknown"),
                payload.get("name", "Unknown"),
                f"{payload.get('start_line')}-{payload.get('end_line')}",
                payload.get("content", ""),
            )
        )

    return results
|
|
|
|
def format_search_results(results):
    """
    Render search results as an HTML string.

    Args:
        results: Either an error message string (passed through unchanged)
            or a list of (score, file_path, name, lines, code) tuples as
            produced by search_code.

    Returns:
        An HTML fragment with one card per result.
    """
    # search_code returns a plain string on error (e.g. empty query).
    if isinstance(results, str):
        return results

    # Escape interpolated content: source chunks routinely contain '<', '>'
    # and '&', which would otherwise break the markup or inject HTML.
    from html import escape

    output = ""
    for score, fpath, name, lines, code in results:
        output += f"""
        <div style="border: 1px solid #ddd; padding: 10px; margin-bottom: 10px; border-radius: 5px;">
            <div style="display: flex; justify-content: space-between; background-color: #f7f7f7; padding: 5px; border-radius: 3px;">
                <strong>{escape(os.path.basename(fpath))} :: {escape(str(name))} (Lines {lines})</strong>
                <span style="color: green;">Score: {score:.4f}</span>
            </div>
            <pre style="background-color: #f0f0f0; padding: 10px; overflow-x: auto; margin-top: 5px;"><code>{escape(code)}</code></pre>
            <div style="font-size: 0.8em; color: gray;">Path: {escape(fpath)}</div>
        </div>
        """
    return output
|
|
|
|
def scan_refactoring_candidates(
    threshold=0.95, sample_size=100, progress=gr.Progress()
):
    """
    Find near-duplicate code chunks by sampling and neighbor queries.

    A full pairwise comparison is O(N^2), so instead up to `sample_size`
    points are scrolled out of the collection and each stored vector is
    reused as a query to fetch its nearest neighbors above `threshold`.

    Args:
        threshold: Minimum similarity score for a match to count.
        sample_size: Number of chunks to pull from the collection.
        progress: Gradio progress tracker.

    Returns:
        A list of {"base": {...}, "matches": [{...}, ...]} dicts, one entry
        per sampled chunk that has at least one near-duplicate.
    """
    candidates = []

    progress(0.1, desc="Scanning for duplicates...")

    # Scroll out the first `sample_size` points with their vectors so each
    # stored vector can be used as a query without re-embedding.
    scroll_result = client.scroll(
        collection_name=COLLECTION_NAME, limit=sample_size, with_vectors=True
    )
    points = scroll_result[0]
    total = len(points)

    for i, point in enumerate(points):
        hits = client.search(
            collection_name=COLLECTION_NAME,
            query_vector=point.vector,
            limit=5,
            score_threshold=threshold,
        )

        # Drop the queried point itself (always a perfect self-match).
        neighbors = [h for h in hits if h.id != point.id]

        if neighbors:
            base_info = {
                "file": point.payload.get("file_name"),
                "name": point.payload.get("name"),
                "code": point.payload.get("content"),
            }
            matches = [
                {
                    "file": n.payload.get("file_name"),
                    "name": n.payload.get("name"),
                    "score": n.score,
                    "code": n.payload.get("content"),
                }
                for n in neighbors
            ]
            candidates.append({"base": base_info, "matches": matches})

        progress((i / total), desc="Scanning...")

    return candidates
|
|
|
|
def format_refactoring_results(candidates):
    """
    Render duplicate-scan results as an HTML report.

    Args:
        candidates: Output of scan_refactoring_candidates — a list of
            {"base": {...}, "matches": [{...}, ...]} dicts.

    Returns:
        An HTML fragment, or a plain message when no duplicates were found.
    """
    if not candidates:
        return "No significant duplicates found in the sample."

    # Escape interpolated code/names so '<', '>' and '&' inside source
    # chunks cannot break the markup or inject HTML. str() guards against
    # None payload fields (payload.get may miss; escape(None) raises).
    from html import escape

    report = f"<h3>Found {len(candidates)} potential refactoring groups</h3>"

    for group in candidates:
        base = group["base"]
        report += f"""
        <div style="border: 2px solid #ffaa00; padding: 10px; margin-bottom: 20px; border-radius: 5px;">
            <h4>Source: {escape(str(base["file"]))} - {escape(str(base["name"]))}</h4>
            <div style="max-height: 150px; overflow-y: auto;"><pre>{escape(str(base["code"]))}</pre></div>
            <hr>
            <h5>Similar Candidates:</h5>
        """
        for m in group["matches"]:
            report += f"""
            <div style="margin-left: 20px; border-left: 3px solid #ddd; padding-left: 10px;">
                <strong>{escape(str(m["file"]))} - {escape(str(m["name"]))}</strong> (Score: {m["score"]:.4f})
                <div style="max-height: 100px; overflow-y: auto; background:#fafafa;"><pre>{escape(str(m["code"]))}</pre></div>
            </div>
            """
        report += "</div>"

    return report
|
|
|
|
| |
|
|
# ---------------------------------------------------------------------------
# Gradio UI: index-status panel + two tabs (semantic search, duplicate scan).
# ---------------------------------------------------------------------------
with gr.Blocks(title="Semantic Code Search") as demo:
    gr.Markdown("# 🔍 Semantic Code Search & Refactoring Tool")

    # Index status panel: shows whether the Qdrant collection is ready and
    # exposes a manual (re)build trigger.
    with gr.Accordion("System Status", open=True):
        status_output = gr.Textbox(
            label="Index Status", value="Checking index...", interactive=False
        )
        init_btn = gr.Button("Initialize / Rebuild Index")

    with gr.Tabs():
        # Tab 1: free-form semantic search (code snippet or natural language).
        with gr.TabItem("Code Search"):
            with gr.Row():
                with gr.Column():
                    query_input = gr.Code(
                        language="python",
                        label="Paste Code Snippet or Natural Language Query",
                    )
                    search_btn = gr.Button("Search", variant="primary")
                    limit_slider = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1, label="Max Results"
                    )

                with gr.Column():
                    results_output = gr.HTML(label="Results")

            # Wire search → formatter; the lambda chains the two functions
            # so the raw tuples never leave the backend.
            search_btn.click(
                fn=lambda q, l: format_search_results(search_code(q, l)),
                inputs=[query_input, limit_slider],
                outputs=results_output,
            )

        # Tab 2: sampled near-duplicate detection across the indexed chunks.
        with gr.TabItem("Refactoring Inspector"):
            gr.Markdown(
                "Scans a random sample of the codebase to find functions that are highly similar (possible duplicates)."
            )
            with gr.Row():
                scan_btn = gr.Button("Scan for Duplicates", variant="secondary")
                threshold_slider = gr.Slider(
                    minimum=0.80,
                    maximum=0.99,
                    value=0.95,
                    step=0.01,
                    label="Similarity Threshold",
                )
                sample_slider = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Sample Size (Chunks)",
                )

            refactor_output = gr.HTML(label="Candidates")

            scan_btn.click(
                fn=lambda t, s: format_refactoring_results(
                    scan_refactoring_candidates(t, s)
                ),
                inputs=[threshold_slider, sample_slider],
                outputs=refactor_output,
            )

    # On page load (and on button click) check/build the index and show status.
    demo.load(initialize_index, outputs=status_output)
    init_btn.click(initialize_index, outputs=status_output)


if __name__ == "__main__":
    demo.launch()
|
|