# Page header captured with the file: Talip7 · "Update app.py" · commit b320624 (verified)
import torch
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# -----------------------------
# LOAD SEMANTIC DATASET
# -----------------------------
# Hub dataset whose "embedding" column is indexed with FAISS so that
# nearest-neighbour queries can be answered at request time.
DATASET_ID = "Talip7/scikit-learn-issues-embeddings-mpnet"
train_ds = load_dataset(DATASET_ID, split="train").add_faiss_index(
    column="embedding"
)

# Both models below are placed on the GPU when one is available, else the CPU.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# LOAD EMBEDDING MODEL
# -----------------------------
# Sentence encoder used to embed incoming queries.
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
encoder = SentenceTransformer(EMBEDDING_MODEL, device=_DEVICE)

# -----------------------------
# LOAD MULTILABEL CLASSIFIER
# -----------------------------
# Sequence classifier loaded in multi-label mode, plus its tokenizer.
CLASSIFIER_ID = "Talip7/scikit-learn-multilabel-classifier"
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
clf_model = AutoModelForSequenceClassification.from_pretrained(
    CLASSIFIER_ID,
    problem_type="multi_label_classification",
)
clf_model = clf_model.to(_DEVICE)
# Inference only: switch the classifier to evaluation mode.
clf_model.eval()
# -----------------------------
# UTILS
# -----------------------------
def predict_labels(text, threshold=0.5):
    """Return the classifier labels whose probability reaches `threshold`.

    The text is tokenized (truncated to at most 512 tokens), passed through
    `clf_model` with gradients disabled, and the first row of logits is
    turned into per-label probabilities with a sigmoid.

    Args:
        text: Input text to classify.
        threshold: Minimum sigmoid probability for a label to be kept.

    Returns:
        A list of label names taken from `clf_model.config.id2label`, in
        label-index order.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    # Move every input tensor to wherever the classifier lives.
    target_device = clf_model.device
    batch = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = clf_model(**batch)

    probabilities = torch.sigmoid(output.logits)[0].cpu().numpy()
    id2label = clf_model.config.id2label
    return [
        id2label[index]
        for index, probability in enumerate(probabilities)
        if probability >= threshold
    ]
def semantic_search(query, k=15):
    """Look up the `k` indexed examples nearest to `query`.

    The query is embedded with `encoder` and matched against the
    "embedding" index of `train_ds`.

    Args:
        query: Text to embed and search for.
        k: Number of nearest examples to request.

    Returns:
        A `(scores, samples)` tuple as produced by
        `train_ds.get_nearest_examples`.
    """
    query_vector = encoder.encode(query, convert_to_numpy=True)
    nearest = train_ds.get_nearest_examples("embedding", query_vector, k=k)
    scores, samples = nearest
    return scores, samples
def hybrid_search(query, alpha=0.7, beta=0.3, max_results=5):
    """Rank issues similar to `query` by semantic similarity and label overlap.

    Candidates come from `semantic_search`. The labels predicted for the
    query by `predict_labels` are compared with each candidate's own labels.
    Every distinct candidate is scored first, and only then are the best
    `max_results` kept, so the label-overlap term can actually change the
    ranking.

    Args:
        query: Free-text issue description.
        alpha: Weight of the semantic-similarity term.
        beta: Weight of the label-overlap term.
        max_results: Maximum number of rows to return; values <= 0 yield none.

    Returns:
        A `(predicted_labels, results)` tuple. `predicted_labels` is a list
        in the order returned by `predict_labels`. `results` is a list of
        dicts with the keys "Issue", "Final score", "Semantic",
        "Label overlap", "Labels" and "URL", best match first.
    """
    distances, candidates = semantic_search(query)

    # Keep the classifier's label order (a set would make the displayed
    # order nondeterministic); dict.fromkeys drops duplicates, if any.
    predicted_labels = list(dict.fromkeys(predict_labels(query)))
    predicted_set = set(predicted_labels)

    seen = set()
    scored = []
    rows = zip(
        distances,
        candidates["issue_number"],
        candidates["labels"],
        candidates["html_url"],
    )
    for distance, issue_id, raw_labels, url in rows:
        # Candidates arrive closest-first, so the first hit per issue is kept.
        if issue_id in seen:
            continue
        seen.add(issue_id)

        # Sorted for a stable display order; tolerate a missing label list.
        issue_labels = sorted(set(raw_labels or ()))
        overlap = (
            len(predicted_set.intersection(issue_labels)) / len(issue_labels)
            if issue_labels
            else 0.0
        )

        # NOTE(review): this assumes the FAISS index uses the library's default
        # L2 metric, where the score is a distance and smaller means closer.
        # It is mapped into (0, 1] so that, like the overlap term, larger is
        # better. If the index is rebuilt with an inner-product metric, use the
        # raw score here instead.
        similarity = 1.0 / (1.0 + max(float(distance), 0.0))
        final_score = alpha * similarity + beta * overlap

        scored.append((
            final_score,
            {
                "Issue": f"#{issue_id}",
                "Final score": round(final_score, 3),
                "Semantic": round(similarity, 3),
                "Label overlap": round(overlap, 2),
                "Labels": ", ".join(issue_labels),
                "URL": url,
            },
        ))

    # Stable sort on the unrounded score: ties keep their semantic order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    limit = max(max_results, 0)
    return predicted_labels, [row for _, row in scored[:limit]]
def run_search(query):
    """Run the hybrid search for `query` and format the outcome for the UI.

    Args:
        query: Text typed by the user. `None`, empty, or whitespace-only
            input is rejected with a prompt instead of being searched.

    Returns:
        A `(label_text, table_html)` tuple: the predicted labels joined with
        ", " (or a fallback message when none were predicted), and an HTML
        table with one row per result. For rejected input the tuple is
        `("Please enter a query.", "")`.
    """
    # Function-scope import so the file's top-level import block is untouched.
    from html import escape

    if not query or not query.strip():
        return "Please enter a query.", ""

    labels, results = hybrid_search(query)
    label_text = ", ".join(labels) if labels else "No label confidently predicted"

    table_head = """
<div style="width:100%; overflow-x:auto;">
<table style="width:100%; border-collapse:collapse;">
<thead>
<tr>
<th style="border:1px solid #555; padding:8px;">Issue</th>
<th style="border:1px solid #555; padding:8px; text-align:center;">Final</th>
<th style="border:1px solid #555; padding:8px; text-align:center;">Semantic</th>
<th style="border:1px solid #555; padding:8px; text-align:center;">Label overlap</th>
<th style="border:1px solid #555; padding:8px;">Labels</th>
<th style="border:1px solid #555; padding:8px;">GitHub URL</th>
</tr>
</thead>
<tbody>
"""

    body_rows = []
    for r in results:
        # Values originate from the dataset, not from this code, so every
        # one is escaped before being placed in markup or an attribute.
        issue = escape(str(r["Issue"]))
        final_score = escape(str(r["Final score"]))
        semantic = escape(str(r["Semantic"]))
        overlap = escape(str(r["Label overlap"]))
        issue_labels = escape(str(r["Labels"]))
        url = escape(str(r["URL"]), quote=True)
        body_rows.append(f"""
<tr>
<td style="border:1px solid #555; padding:8px;">{issue}</td>
<td style="border:1px solid #555; padding:8px; text-align:center;">{final_score}</td>
<td style="border:1px solid #555; padding:8px; text-align:center;">{semantic}</td>
<td style="border:1px solid #555; padding:8px; text-align:center;">{overlap}</td>
<td style="border:1px solid #555; padding:8px;">{issue_labels}</td>
<td style="border:1px solid #555; padding:8px;">
<a href="{url}" target="_blank" rel="noopener noreferrer">{url}</a>
</td>
</tr>
""")

    # Join once rather than growing the string inside the loop.
    table_html = table_head + "".join(body_rows) + "</tbody></table></div>"
    return label_text, table_html
# -----------------------------
# GRADIO UI
# -----------------------------
# Stylesheet handed to gr.Blocks. The CSS comments embedded in the string are
# in Turkish and read, in order: "widen the page's main container",
# "Markdown tables at full width", "center the numeric columns".
_APP_CSS = """
/* Sayfanın ana container'ını genişlet */
.gradio-container {
max-width: 100% !important;
padding-left: 40px;
padding-right: 40px;
}
/* Markdown tablo tam genişlik */
.gr-markdown table {
width: 100%;
table-layout: fixed;
}
.gr-markdown th, .gr-markdown td {
text-align: left;
padding: 8px;
word-wrap: break-word;
}
/* Sayısal kolonlar ortalansın */
.gr-markdown td:nth-child(2),
.gr-markdown td:nth-child(3),
.gr-markdown td:nth-child(4) {
text-align: center;
}
"""

# Title and tagline rendered at the top of the page.
_HEADER_MARKDOWN = """
# 🤗 GitHub Issue Hybrid Search & Auto-Label Assistant
**Semantic Search (MPNet) + Multilabel Classification (DistilBERT)**
Precision-first hybrid ranking on real scikit-learn issues.
"""

demo = gr.Blocks(title="GitHub Issue Hybrid Search", css=_APP_CSS)
with demo:
    # Components are declared in the order they are created on the page.
    gr.Markdown(_HEADER_MARKDOWN)
    query = gr.Textbox(
        label="Describe the issue",
        placeholder="e.g. RandomForestClassifier crashes when sample_weight is None",
    )
    btn = gr.Button("Search")
    predicted = gr.Textbox(label="Predicted labels", interactive=False)
    results_html = gr.HTML()

    # Clicking "Search" sends the query to run_search, which fills the
    # predicted-labels box and the HTML results area.
    btn.click(fn=run_search, inputs=query, outputs=[predicted, results_html])

demo.launch()