# Source: Hugging Face Space by stevenbucaille, commit 63b131e ("Add app.py and requirements.txt")
import html
import os
import uuid

import gradio as gr
import pandas as pd
import tqdm
from datasets import load_dataset
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
# --- Configurations ---
# Hugging Face dataset holding pre-chunked code with an "embedding" column.
DATASET_ID = "stevenbucaille/semantic-transformers"
# Name of the Qdrant collection the app indexes into and queries.
COLLECTION_NAME = "transformers_code"
# We'll use a local directory for Qdrant storage within the Space
# (embedded mode — no separate Qdrant server process is required).
QDRANT_PATH = "./qdrant_data"
# Model must match the one used for embedding generation.
# Ideally we read this from the dataset metadata, but let's default to the popular one
# or make it configurable if we want. For now, let's assume valid model is available or use a default one
# that matches what was likely used (Snowflake/snowflake-arctic-embed-m).
# NOTE: The dataset should have "embedding" column. If so, we don't need the model for indexing, ONLY for query.
DEFAULT_MODEL = "Snowflake/snowflake-arctic-embed-m"
print("Initializing Qdrant Client...")
# Embedded (file-backed) Qdrant client; data persists across restarts in QDRANT_PATH.
client = QdrantClient(path=QDRANT_PATH)
print("Loading Embedding Model for queries...")
# Query-time encoder only; indexing reuses the embeddings shipped in the dataset.
# NOTE(review): assumes this model matches the dataset's embedding dimension — confirm.
model = SentenceTransformer(DEFAULT_MODEL, trust_remote_code=True)
def initialize_index(progress=gr.Progress()):
    """
    Ensure the Qdrant collection exists and is populated.

    If the collection already holds vectors, return immediately.
    Otherwise download the dataset, create the collection sized to the
    embedding dimension, and upsert all rows in batches.

    Returns:
        A human-readable status string for the UI status textbox.
    """
    collections = client.get_collections().collections
    exists = any(c.name == COLLECTION_NAME for c in collections)
    if exists:
        count = client.count(COLLECTION_NAME).count
        if count > 0:
            return f"Index already exists with {count} vectors. Ready."
    # Needs indexing
    progress(0.1, desc=f"Pulling dataset {DATASET_ID}...")
    try:
        ds = load_dataset(DATASET_ID, split="train")
        df = ds.to_pandas()
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        return f"Error loading dataset: {e}"
    if "embedding" not in df.columns:
        return "Error: Dataset does not contain 'embedding' column."
    # Drop rows with missing embeddings, then renumber so the row position
    # is a dense, unique integer usable as the Qdrant point ID (after
    # dropna the original index has gaps and may collide on re-runs).
    df = df.dropna(subset=["embedding"]).reset_index(drop=True)
    if df.empty:
        return "Error: Dataset has no rows with embeddings."
    total_vectors = len(df)
    # Infer the vector dimension from the first row's embedding.
    vec_size = len(df.iloc[0]["embedding"])
    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=vec_size, distance=Distance.COSINE),
    )
    # Upload in batches to bound memory and round-trip count.
    BATCH_SIZE = 500
    points = []
    progress(0.2, desc="Indexing vectors...")
    for idx, row in tqdm.tqdm(df.iterrows(), total=total_vectors):
        # Payload is every column except the embedding itself.
        payload = row.drop("embedding").to_dict()
        points.append(
            PointStruct(
                id=int(idx),
                # to_pandas() stores list columns as numpy arrays;
                # qdrant expects a plain list of floats.
                vector=list(row["embedding"]),
                payload=payload,
            )
        )
        if len(points) >= BATCH_SIZE:
            client.upsert(collection_name=COLLECTION_NAME, points=points)
            points = []
        if idx % 5000 == 0:
            progress(
                0.2 + 0.8 * (idx / total_vectors),
                desc=f"Indexed {idx}/{total_vectors}...",
            )
    # Flush the final partial batch.
    if points:
        client.upsert(collection_name=COLLECTION_NAME, points=points)
    return f"Successfully indexed {total_vectors} chunks."
def search_code(query, limit=5):
    """
    Embed *query* and run a similarity search against the Qdrant index.

    Args:
        query: Code snippet or natural-language text (may be None/empty,
            since gr.Code can deliver None when the editor is untouched).
        limit: Maximum number of hits to return.

    Returns:
        Either an error string, or a list of
        (score, file_path, name, "start-end", code) tuples.
    """
    # gr.Code can hand us None; guard before calling .strip().
    if not query or not query.strip():
        return "Please enter a query."
    # Embed query with the same model family used to build the index.
    query_vector = model.encode(query)
    hits = client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_vector,
        # Gradio sliders can deliver floats; Qdrant requires an int limit.
        limit=int(limit),
    )
    results = []
    for hit in hits:
        payload = hit.payload or {}
        results.append(
            (
                hit.score,
                payload.get("file_path", "Unknown"),
                payload.get("name", "Unknown"),
                f"{payload.get('start_line')}-{payload.get('end_line')}",
                payload.get("content", ""),
            )
        )
    return results
def format_search_results(results):
    """
    Render search results as an HTML card list.

    Args:
        results: Either an error string (passed through unchanged) or a
            list of (score, file_path, name, lines, code) tuples from
            search_code.

    Returns:
        An HTML string for a gr.HTML component.
    """
    if isinstance(results, str):
        return results  # Error message from search_code
    out = ""
    for score, fpath, name, lines, code in results:
        # Escape snippet content: code routinely contains <, >, & which
        # would otherwise be interpreted as HTML and garble the display.
        safe_base = html.escape(os.path.basename(fpath))
        safe_name = html.escape(str(name))
        safe_code = html.escape(code or "")
        safe_path = html.escape(str(fpath))
        out += f"""
        <div style="border: 1px solid #ddd; padding: 10px; margin-bottom: 10px; border-radius: 5px;">
            <div style="display: flex; justify-content: space-between; background-color: #f7f7f7; padding: 5px; border-radius: 3px;">
                <strong>{safe_base} :: {safe_name} (Lines {lines})</strong>
                <span style="color: green;">Score: {score:.4f}</span>
            </div>
            <pre style="background-color: #f0f0f0; padding: 10px; overflow-x: auto; margin-top: 5px;"><code>{safe_code}</code></pre>
            <div style="font-size: 0.8em; color: gray;">Path: {safe_path}</div>
        </div>
        """
    return out
def scan_refactoring_candidates(
    threshold=0.95, sample_size=100, progress=gr.Progress()
):
    """
    Find near-duplicate code chunks via similarity search.

    A full O(N^2) pairwise comparison is too slow, so we scroll the first
    *sample_size* points and query each one's nearest neighbours, keeping
    groups whose cosine similarity exceeds *threshold*.

    Args:
        threshold: Minimum similarity score for a neighbour to count.
        sample_size: Number of points to scan (one scroll page).

    Returns:
        A list of {"base": {...}, "matches": [{...}, ...]} dicts.
    """
    candidates = []
    progress(0.1, desc="Scanning for duplicates...")
    # Fetch one page of points, vectors included, to use as query seeds.
    # Scanning only a sample keeps the UI responsive; a real deployment
    # would run this as a background job.
    points, _next_offset = client.scroll(
        collection_name=COLLECTION_NAME,
        # Gradio sliders can deliver floats; scroll requires an int limit.
        limit=int(sample_size),
        with_vectors=True,
    )
    total = len(points)
    for i, point in enumerate(points):
        # Query the index for this point's nearest neighbours.
        hits = client.search(
            collection_name=COLLECTION_NAME,
            query_vector=point.vector,
            limit=5,  # top 5 nearest
            score_threshold=threshold,
        )
        # Every point matches itself with score 1.0; drop it.
        neighbors = [h for h in hits if h.id != point.id]
        if neighbors:
            payload = point.payload or {}
            # NOTE(review): this reads "file_name" while search_code reads
            # "file_path" — verify which key the dataset actually provides.
            base_info = {
                "file": payload.get("file_name"),
                "name": payload.get("name"),
                "code": payload.get("content"),
            }
            matches = [
                {
                    "file": (n.payload or {}).get("file_name"),
                    "name": (n.payload or {}).get("name"),
                    "score": n.score,
                    "code": (n.payload or {}).get("content"),
                }
                for n in neighbors
            ]
            candidates.append({"base": base_info, "matches": matches})
        progress(i / total, desc="Scanning...")
    return candidates
def format_refactoring_results(candidates):
    """
    Render duplicate groups from scan_refactoring_candidates as HTML.

    Args:
        candidates: List of {"base": {...}, "matches": [{...}, ...]} dicts;
            an empty/None value yields a "nothing found" message.

    Returns:
        An HTML string for a gr.HTML component.
    """
    if not candidates:
        return "No significant duplicates found in the sample."
    out = f"<h3>Found {len(candidates)} potential refactoring groups</h3>"
    for group in candidates:
        base = group["base"]
        # Escape code/content so snippets containing <, >, & render
        # literally instead of being interpreted as HTML.
        out += f"""
        <div style="border: 2px solid #ffaa00; padding: 10px; margin-bottom: 20px; border-radius: 5px;">
            <h4>Source: {html.escape(str(base["file"]))} - {html.escape(str(base["name"]))}</h4>
            <div style="max-height: 150px; overflow-y: auto;"><pre>{html.escape(str(base["code"]))}</pre></div>
            <hr>
            <h5>Similar Candidates:</h5>
        """
        for m in group["matches"]:
            out += f"""
            <div style="margin-left: 20px; border-left: 3px solid #ddd; padding-left: 10px;">
                <strong>{html.escape(str(m["file"]))} - {html.escape(str(m["name"]))}</strong> (Score: {m["score"]:.4f})
                <div style="max-height: 100px; overflow-y: auto; background:#fafafa;"><pre>{html.escape(str(m["code"]))}</pre></div>
            </div>
            """
        out += "</div>"
    return out
# --- UI Interface ---
with gr.Blocks(title="Semantic Code Search") as demo:
    gr.Markdown("# 🔍 Semantic Code Search & Refactoring Tool")
    # Status panel: reports whether the Qdrant index is built and ready.
    with gr.Accordion("System Status", open=True):
        status_output = gr.Textbox(
            label="Index Status", value="Checking index...", interactive=False
        )
        init_btn = gr.Button("Initialize / Rebuild Index")
    with gr.Tabs():
        # TAB 1: SEARCH — embed a snippet or natural-language query and
        # show the nearest indexed code chunks.
        with gr.TabItem("Code Search"):
            with gr.Row():
                with gr.Column():
                    query_input = gr.Code(
                        language="python",
                        label="Paste Code Snippet or Natural Language Query",
                    )
                    search_btn = gr.Button("Search", variant="primary")
                    limit_slider = gr.Slider(
                        minimum=1, maximum=20, value=5, step=1, label="Max Results"
                    )
                with gr.Column():
                    results_output = gr.HTML(label="Results")
            # Pipe raw search tuples through the HTML formatter.
            search_btn.click(
                fn=lambda q, l: format_search_results(search_code(q, l)),
                inputs=[query_input, limit_slider],
                outputs=results_output,
            )
        # TAB 2: REFACTORING — sample the index and surface near-duplicate
        # chunks as refactoring candidates.
        with gr.TabItem("Refactoring Inspector"):
            gr.Markdown(
                "Scans a random sample of the codebase to find functions that are highly similar (possible duplicates)."
            )
            with gr.Row():
                scan_btn = gr.Button("Scan for Duplicates", variant="secondary")
                threshold_slider = gr.Slider(
                    minimum=0.80,
                    maximum=0.99,
                    value=0.95,
                    step=0.01,
                    label="Similarity Threshold",
                )
                sample_slider = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Sample Size (Chunks)",
                )
            refactor_output = gr.HTML(label="Candidates")
            scan_btn.click(
                fn=lambda t, s: format_refactoring_results(
                    scan_refactoring_candidates(t, s)
                ),
                inputs=[threshold_slider, sample_slider],
                outputs=refactor_output,
            )
    # Init event: check/build the index on page load, and again on demand.
    demo.load(initialize_index, outputs=status_output)
    init_btn.click(initialize_index, outputs=status_output)
# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()