Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# -----------------------------
# LOAD SEMANTIC DATASET
# -----------------------------
# Hub dataset of scikit-learn issues with a precomputed "embedding" column
# (presumably all-mpnet-base-v2 vectors, per the repo name — TODO confirm).
DATASET_ID = "Talip7/scikit-learn-issues-embeddings-mpnet"
# Attach an in-memory FAISS index to the "embedding" column. add_faiss_index
# returns the dataset itself, so the call can be chained. No metric_type is
# passed, so datasets builds its default flat L2 index: get_nearest_examples
# on this index yields distances (lower = closer), not similarities.
train_ds = load_dataset(DATASET_ID, split="train").add_faiss_index(column="embedding")
# -----------------------------
# LOAD EMBEDDING MODEL
# -----------------------------
# Query encoder used by semantic_search. It should match the model that
# produced the dataset's "embedding" column so vectors are comparable
# (presumably the same model — TODO confirm against the dataset card).
EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Prefer the GPU when one is visible to torch, otherwise run on CPU.
_encoder_device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = SentenceTransformer(EMBEDDING_MODEL, device=_encoder_device)
# -----------------------------
# LOAD MULTILABEL CLASSIFIER
# -----------------------------
# Fine-tuned sequence classifier and its tokenizer, both loaded from the Hub.
CLASSIFIER_ID = "Talip7/scikit-learn-multilabel-classifier"
# Same device policy as the encoder: GPU when available, else CPU.
_clf_device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
# Override config.problem_type to multi-label; predict_labels scores every
# label with an independent sigmoid rather than a softmax.
clf_model = AutoModelForSequenceClassification.from_pretrained(
    CLASSIFIER_ID,
    problem_type="multi_label_classification",
)
clf_model = clf_model.to(_clf_device)
# Inference only: put layers such as dropout into evaluation mode.
clf_model.eval()
# -----------------------------
# UTILS
# -----------------------------
def predict_labels(text, threshold=0.5):
    """Return the classifier labels whose probability reaches ``threshold``.

    The text is tokenized (truncated to 512 tokens) and run through
    ``clf_model`` without gradient tracking. Each logit is then mapped to an
    independent probability with a sigmoid (multi-label setup).

    Args:
        text: Issue description to classify. Only the first row of the model
            output is read, so this assumes a single string.
        threshold: Minimum probability (inclusive) for a label to be kept.

    Returns:
        List of label names from ``clf_model.config.id2label``, in label-index
        order.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )
    # Move every input tensor to the device the model was placed on.
    encoded = {name: tensor.to(clf_model.device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = clf_model(**encoded).logits
    # Row 0 is the single input text; one sigmoid probability per label.
    probabilities = torch.sigmoid(logits)[0].cpu().numpy()
    id2label = clf_model.config.id2label
    return [id2label[idx] for idx, prob in enumerate(probabilities) if prob >= threshold]
def semantic_search(query, k=15):
    """Find the ``k`` dataset rows whose embeddings are nearest to ``query``.

    Args:
        query: Free-text query; encoded with ``encoder`` into a numpy vector.
        k: Number of neighbours to request from the FAISS index.

    Returns:
        ``(scores, samples)`` as produced by ``train_ds.get_nearest_examples``.
        ``samples`` maps column names to lists of values, one per hit. The
        index is built without ``metric_type``, so ``scores`` are L2 distances
        in which lower means closer.
    """
    query_vector = encoder.encode(query, convert_to_numpy=True)
    distances, examples = train_ds.get_nearest_examples("embedding", query_vector, k=k)
    return distances, examples
def hybrid_search(query, alpha=0.7, beta=0.3, max_results=5, k=15):
    """Rank nearby issues by a weighted mix of semantic similarity and label overlap.

    Fetches the ``k`` nearest issues with ``semantic_search`` and predicts labels
    for the query with ``predict_labels``. Candidates are de-duplicated by issue
    number, scored, re-ranked by that score, and the best ``max_results`` are
    returned.

    Scoring:
        * ``semantic = 1 / (1 + distance)``. The FAISS index on ``train_ds`` is
          built without ``metric_type``, so it returns L2 distances, where lower
          means closer. Mapping them to a similarity in (0, 1] makes closer
          issues score higher, which the weighted sum below relies on. Using the
          raw distance would reward far-away issues.
        * ``overlap`` is the fraction of the issue's labels that were also
          predicted for the query (0.0 for an unlabeled issue).
        * ``final = alpha * semantic + beta * overlap``.

    Args:
        query: Free-text issue description.
        alpha: Weight of the semantic similarity term.
        beta: Weight of the label-overlap term.
        max_results: Maximum number of result rows to return.
        k: Number of nearest neighbours fetched before re-ranking.

    Returns:
        ``(predicted_labels, results)``:
            * ``predicted_labels`` is a sorted list of label names.
            * ``results`` is a list of dicts with keys "Issue", "Final score",
              "Semantic", "Label overlap", "Labels" and "URL", ordered by final
              score, highest first.
    """
    distances, neighbours = semantic_search(query, k=k)
    predicted_labels = set(predict_labels(query))

    seen = set()
    scored = []
    for distance, issue_id, raw_labels, url in zip(
        distances,
        neighbours["issue_number"],
        neighbours["labels"],
        neighbours["html_url"],
    ):
        # Show each issue once; FAISS returns hits nearest-first, so the first
        # occurrence is the closest one.
        if issue_id in seen:
            continue
        seen.add(issue_id)

        # Clamp: flat L2 search can return tiny negative values from rounding.
        semantic = 1.0 / (1.0 + max(float(distance), 0.0))
        issue_labels = set(raw_labels or ())
        overlap = (
            len(issue_labels & predicted_labels) / len(issue_labels)
            if issue_labels
            else 0.0
        )
        final_score = alpha * semantic + beta * overlap
        scored.append((final_score, {
            "Issue": f"#{issue_id}",
            "Final score": round(final_score, 3),
            "Semantic": round(semantic, 3),
            "Label overlap": round(overlap, 2),
            "Labels": ", ".join(sorted(issue_labels)),
            "URL": url,
        }))

    # Re-rank by the unrounded hybrid score. The sort is stable, so ties keep
    # FAISS (nearest-first) order.
    scored.sort(key=lambda item: item[0], reverse=True)
    results = [row for _, row in scored[:max(max_results, 0)]]
    return sorted(predicted_labels), results
def run_search(query):
    """Gradio callback: run ``hybrid_search`` and render the hits as an HTML table.

    Args:
        query: Text from the "Describe the issue" textbox.

    Returns:
        ``(label_text, table_html)``:
            * ``label_text`` is the predicted labels joined with ", ", or
              "No label confidently predicted".
            * ``table_html`` is an HTML table with one row per result.

        For an empty or blank query, returns ``("Please enter a query.", "")``.
    """
    # Local import: only this renderer needs HTML escaping.
    from html import escape

    # Gradio normally sends "" for an empty box; also tolerate None.
    if not query or not query.strip():
        return "Please enter a query.", ""

    labels, results = hybrid_search(query)
    label_text = ", ".join(labels) if labels else "No label confidently predicted"

    td = "border:1px solid #555; padding:8px;"
    td_center = td + " text-align:center;"
    parts = [
        '<div style="width:100%; overflow-x:auto;">',
        '<table style="width:100%; border-collapse:collapse;">',
        "<thead>",
        "<tr>",
        f'<th style="{td}">Issue</th>',
        f'<th style="{td_center}">Final</th>',
        f'<th style="{td_center}">Semantic</th>',
        f'<th style="{td_center}">Label overlap</th>',
        f'<th style="{td}">Labels</th>',
        f'<th style="{td}">GitHub URL</th>',
        "</tr>",
        "</thead>",
        "<tbody>",
    ]
    for r in results:
        # Labels and URLs come from the dataset, so escape them before placing
        # them in element text or in the href attribute.
        url = escape(str(r["URL"]))
        parts.extend([
            "<tr>",
            f'<td style="{td}">{escape(str(r["Issue"]))}</td>',
            f'<td style="{td_center}">{r["Final score"]}</td>',
            f'<td style="{td_center}">{r["Semantic"]}</td>',
            f'<td style="{td_center}">{r["Label overlap"]}</td>',
            f'<td style="{td}">{escape(str(r["Labels"]))}</td>',
            f'<td style="{td}">'
            f'<a href="{url}" target="_blank" rel="noopener noreferrer">{url}</a>'
            "</td>",
            "</tr>",
        ])
    parts.append("</tbody></table></div>")
    return label_text, "\n".join(parts)
# -----------------------------
# GRADIO UI
# -----------------------------
with gr.Blocks(
    title="GitHub Issue Hybrid Search",
    # Page CSS (kept verbatim). The .gradio-container rule widens the app to
    # the full page width with 40px side padding. The .gr-markdown rules style
    # Markdown tables; the results table itself is rendered through gr.HTML
    # with inline styles (see run_search).
    css="""
/* Sayfanın ana container'ını genişlet */
.gradio-container {
max-width: 100% !important;
padding-left: 40px;
padding-right: 40px;
}
/* Markdown tablo tam genişlik */
.gr-markdown table {
width: 100%;
table-layout: fixed;
}
.gr-markdown th, .gr-markdown td {
text-align: left;
padding: 8px;
word-wrap: break-word;
}
/* Sayısal kolonlar ortalansın */
.gr-markdown td:nth-child(2),
.gr-markdown td:nth-child(3),
.gr-markdown td:nth-child(4) {
text-align: center;
}
""",
) as demo:
    gr.Markdown(
        """
# 🤗 GitHub Issue Hybrid Search & Auto-Label Assistant
**Semantic Search (MPNet) + Multilabel Classification (DistilBERT)**
Precision-first hybrid ranking on real scikit-learn issues.
"""
    )
    query = gr.Textbox(
        label="Describe the issue",
        placeholder="e.g. RandomForestClassifier crashes when sample_weight is None",
    )
    btn = gr.Button("Search")
    predicted = gr.Textbox(label="Predicted labels", interactive=False)
    results_html = gr.HTML()

    # Both the button and pressing Enter in the query box trigger the search;
    # run_search returns (label_text, table_html) for these two outputs.
    search_outputs = [predicted, results_html]
    btn.click(fn=run_search, inputs=query, outputs=search_outputs)
    query.submit(fn=run_search, inputs=query, outputs=search_outputs)

demo.launch()