import torch import gradio as gr from datasets import load_dataset from sentence_transformers import SentenceTransformer from transformers import AutoTokenizer, AutoModelForSequenceClassification # ----------------------------- # LOAD SEMANTIC DATASET # ----------------------------- DATASET_ID = "Talip7/scikit-learn-issues-embeddings-mpnet" train_ds = load_dataset(DATASET_ID, split="train") train_ds = train_ds.add_faiss_index(column="embedding") # ----------------------------- # LOAD EMBEDDING MODEL # ----------------------------- EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2" encoder = SentenceTransformer( EMBEDDING_MODEL, device="cuda" if torch.cuda.is_available() else "cpu" ) # ----------------------------- # LOAD MULTILABEL CLASSIFIER # ----------------------------- CLASSIFIER_ID = "Talip7/scikit-learn-multilabel-classifier" tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID) clf_model = AutoModelForSequenceClassification.from_pretrained( CLASSIFIER_ID, problem_type="multi_label_classification" ).to("cuda" if torch.cuda.is_available() else "cpu") clf_model.eval() # ----------------------------- # UTILS # ----------------------------- def predict_labels(text, threshold=0.5): inputs = tokenizer( text, truncation=True, padding=True, max_length=512, return_tensors="pt" ) inputs = {k: v.to(clf_model.device) for k, v in inputs.items()} with torch.no_grad(): logits = clf_model(**inputs).logits probs = torch.sigmoid(logits)[0].cpu().numpy() labels = [] for i, p in enumerate(probs): if p >= threshold: labels.append(clf_model.config.id2label[i]) return labels def semantic_search(query, k=15): query_emb = encoder.encode(query, convert_to_numpy=True) scores, samples = train_ds.get_nearest_examples( "embedding", query_emb, k=k ) return scores, samples def hybrid_search(query, alpha=0.7, beta=0.3, max_results=5): sem_scores, sem_results = semantic_search(query) predicted_labels = set(predict_labels(query)) seen = set() results = [] for i in range(len(sem_scores)): issue_id = sem_results["issue_number"][i] if issue_id in seen: continue seen.add(issue_id) issue_labels = set(sem_results["labels"][i]) overlap = ( len(issue_labels & predicted_labels) / len(issue_labels) if issue_labels else 0.0 ) final_score = alpha * float(sem_scores[i]) + beta * overlap results.append({ "Issue": f"#{issue_id}", "Final score": round(final_score, 3), "Semantic": round(float(sem_scores[i]), 3), "Label overlap": round(overlap, 2), "Labels": ", ".join(issue_labels), "URL": sem_results["html_url"][i], }) if len(results) >= max_results: break return list(predicted_labels), results def run_search(query): if not query.strip(): return "Please enter a query.", "" labels, results = hybrid_search(query) label_text = ", ".join(labels) if labels else "No label confidently predicted" html = """
| Issue | Final | Semantic | Label overlap | Labels | GitHub URL |
|---|---|---|---|---|---|
| {r['Issue']} | {r['Final score']} | {r['Semantic']} | {r['Label overlap']} | {r['Labels']} | {r['URL']} |