"""Gradio app: hybrid semantic + label-aware search over scikit-learn GitHub issues.

Pipeline:
1. Embed the user's issue description with an MPNet sentence encoder and
   retrieve the nearest issues from a FAISS index built over a precomputed
   embeddings dataset.
2. Predict labels for the query with a multilabel sequence classifier.
3. Re-rank the retrieved issues by a weighted mix of semantic similarity and
   label overlap, and render the top results as an HTML table.
"""

from html import escape

import torch
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# One device decision shared by both models.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# LOAD SEMANTIC DATASET
# -----------------------------
DATASET_ID = "Talip7/scikit-learn-issues-embeddings-mpnet"

train_ds = load_dataset(DATASET_ID, split="train")
# Without `metric_type`, `add_faiss_index` builds a flat L2 index, so the
# scores returned by `get_nearest_examples` are distances (lower = closer),
# not similarities.
train_ds = train_ds.add_faiss_index(column="embedding")

# -----------------------------
# LOAD EMBEDDING MODEL
# -----------------------------
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"

encoder = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)

# -----------------------------
# LOAD MULTILABEL CLASSIFIER
# -----------------------------
CLASSIFIER_ID = "Talip7/scikit-learn-multilabel-classifier"

tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
clf_model = AutoModelForSequenceClassification.from_pretrained(
    CLASSIFIER_ID,
    problem_type="multi_label_classification",
).to(DEVICE)
clf_model.eval()


# -----------------------------
# UTILS
# -----------------------------
def predict_labels(text, threshold=0.5):
    """Return the classifier labels whose sigmoid probability is >= ``threshold``.

    Args:
        text: Free-text issue description.
        threshold: Per-label probability cut-off.

    Returns:
        List of label names taken from ``clf_model.config.id2label``, in
        label-id order.
    """
    inputs = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    inputs = {k: v.to(clf_model.device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = clf_model(**inputs).logits

    # Multilabel: an independent sigmoid per label; [0] selects the single
    # row of this one-text batch.
    probs = torch.sigmoid(logits)[0].cpu().numpy()
    id2label = clf_model.config.id2label
    return [id2label[i] for i, p in enumerate(probs) if p >= threshold]


def semantic_search(query, k=15):
    """Retrieve the ``k`` issues nearest to ``query`` in embedding space.

    Returns:
        ``(scores, samples)`` as produced by ``Dataset.get_nearest_examples``:
        ``scores`` are FAISS distances (lower = closer, nearest first) and
        ``samples`` maps each column name to a list of values aligned with
        ``scores``.
    """
    query_emb = encoder.encode(query, convert_to_numpy=True)
    scores, samples = train_ds.get_nearest_examples("embedding", query_emb, k=k)
    return scores, samples


def _distance_to_similarity(distance):
    """Map a non-negative FAISS distance to a similarity in (0, 1] (higher = closer)."""
    return 1.0 / (1.0 + float(distance))


def hybrid_search(query, alpha=0.7, beta=0.3, max_results=5, k=15):
    """Re-rank the semantic neighbours of ``query`` by similarity and label overlap.

    Each unique issue is scored as ``alpha * similarity + beta * overlap``.
    ``similarity`` is the FAISS distance mapped into (0, 1]. ``overlap`` is
    the fraction of the issue's own labels that the classifier predicted for
    the query.

    Args:
        query: Free-text issue description.
        alpha: Weight of the semantic-similarity term.
        beta: Weight of the label-overlap term.
        max_results: Maximum number of issues returned after re-ranking.
        k: Number of nearest neighbours retrieved before re-ranking.

    Returns:
        ``(predicted_labels, results)``: the sorted predicted label names, and
        a list of row dicts ordered by descending final score.
    """
    sem_scores, sem_results = semantic_search(query, k=k)
    predicted_labels = set(predict_labels(query))

    seen = set()
    scored = []
    for i, distance in enumerate(sem_scores):
        issue_id = sem_results["issue_number"][i]
        if issue_id in seen:
            continue
        seen.add(issue_id)

        # `or ()` tolerates a null labels cell, in case the dataset has any.
        issue_labels = set(sem_results["labels"][i] or ())
        overlap = (
            len(issue_labels & predicted_labels) / len(issue_labels)
            if issue_labels
            else 0.0
        )
        similarity = _distance_to_similarity(distance)
        final_score = alpha * similarity + beta * overlap

        scored.append((final_score, {
            "Issue": f"#{issue_id}",
            "Final score": round(final_score, 3),
            "Semantic": round(similarity, 3),
            "Label overlap": round(overlap, 2),
            # Sorted so the display order is stable across runs.
            "Labels": ", ".join(sorted(issue_labels)),
            "URL": sem_results["html_url"][i],
        }))

    # Rank the whole candidate pool before truncating. The sort is stable, so
    # ties keep their semantic (nearest-first) order.
    scored.sort(key=lambda item: item[0], reverse=True)
    results = [row for _, row in scored[:max_results]]
    return sorted(predicted_labels), results


# Column headers of the results table, in the order the cells are emitted.
_TABLE_COLUMNS = ("Issue", "Final", "Semantic", "Label overlap", "Labels", "GitHub URL")


def _render_results_table(results):
    """Render ``hybrid_search`` rows as an HTML table, escaping every data value."""
    if not results:
        return "<p>No matching issues found.</p>"

    header = "".join(f"<th>{escape(col)}</th>" for col in _TABLE_COLUMNS)
    rows = []
    for r in results:
        url = str(r["URL"])
        safe_url = escape(url, quote=True)
        # Only turn http(s) URLs into links, so a malformed value cannot
        # become e.g. a javascript: link.
        if url.startswith(("https://", "http://")):
            link = (
                f'<a href="{safe_url}" target="_blank" '
                f'rel="noopener noreferrer">{safe_url}</a>'
            )
        else:
            link = safe_url
        cells = (
            escape(str(r["Issue"])),
            r["Final score"],
            r["Semantic"],
            r["Label overlap"],
            escape(str(r["Labels"])),
            link,
        )
        rows.append("<tr>" + "".join(f"<td>{c}</td>" for c in cells) + "</tr>")

    return (
        "<table>"
        f"<thead><tr>{header}</tr></thead>"
        f"<tbody>{''.join(rows)}</tbody>"
        "</table>"
    )


def run_search(query):
    """Gradio callback: return ``(predicted-labels text, results HTML)`` for ``query``."""
    if not query or not query.strip():
        return "Please enter a query.", ""

    labels, results = hybrid_search(query)
    label_text = ", ".join(labels) if labels else "No label confidently predicted"
    return label_text, _render_results_table(results)


# -----------------------------
# GRADIO UI
# -----------------------------
# The first rule widens Gradio's main container. Its CSS comment is Turkish for
# "widen the page's main container". The remaining rules style the results
# table that run_search renders into the component with elem_id="results".
APP_CSS = """
/* Sayfanın ana container'ını genişlet */
.gradio-container {
    max-width: 100% !important;
    padding-left: 40px;
    padding-right: 40px;
}

/* Results table: full width */
#results table {
    width: 100%;
    table-layout: fixed;
}

#results th,
#results td {
    text-align: left;
    padding: 8px;
    word-wrap: break-word;
}

/* Center the numeric columns (Final, Semantic, Label overlap) */
#results td:nth-child(2),
#results td:nth-child(3),
#results td:nth-child(4) {
    text-align: center;
}
"""

with gr.Blocks(title="GitHub Issue Hybrid Search", css=APP_CSS) as demo:
    gr.Markdown(
        """
        # 🤗 GitHub Issue Hybrid Search & Auto-Label Assistant

        **Semantic Search (MPNet) + Multilabel Classification (DistilBERT)**

        Precision-first hybrid ranking on real scikit-learn issues.
        """
    )

    query = gr.Textbox(
        label="Describe the issue",
        placeholder="e.g. RandomForestClassifier crashes when sample_weight is None",
    )
    btn = gr.Button("Search")

    predicted = gr.Textbox(label="Predicted labels", interactive=False)
    # elem_id lets APP_CSS target the rendered results table.
    results_html = gr.HTML(elem_id="results")

    btn.click(fn=run_search, inputs=query, outputs=[predicted, results_html])

demo.launch()