Talip7 commited on
Commit
5cb25d0
·
verified ·
1 Parent(s): 6ca3b6b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from datasets import load_dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ # -----------------------------
8
+ # LOAD SEMANTIC DATASET
9
+ # -----------------------------
10
+ DATASET_ID = "Talip7/scikit-learn-issues-embeddings-mpnet"
11
+
12
+ train_ds = load_dataset(DATASET_ID, split="train")
13
+ train_ds = train_ds.add_faiss_index(column="embedding")
14
+
15
+ # -----------------------------
16
+ # LOAD EMBEDDING MODEL
17
+ # -----------------------------
18
+ EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
19
+
20
+ encoder = SentenceTransformer(
21
+ EMBEDDING_MODEL,
22
+ device="cuda" if torch.cuda.is_available() else "cpu"
23
+ )
24
+
25
+ # -----------------------------
26
+ # LOAD MULTILABEL CLASSIFIER
27
+ # -----------------------------
28
+ CLASSIFIER_ID = "Talip7/scikit-learn-multilabel-classifier"
29
+
30
+ tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
31
+ clf_model = AutoModelForSequenceClassification.from_pretrained(
32
+ CLASSIFIER_ID,
33
+ problem_type="multi_label_classification"
34
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ clf_model.eval()
37
+
38
+ # -----------------------------
39
+ # UTILS
40
+ # -----------------------------
41
+ def predict_labels(text, threshold=0.5):
42
+ inputs = tokenizer(
43
+ text,
44
+ truncation=True,
45
+ padding=True,
46
+ max_length=512,
47
+ return_tensors="pt"
48
+ )
49
+ inputs = {k: v.to(clf_model.device) for k, v in inputs.items()}
50
+
51
+ with torch.no_grad():
52
+ logits = clf_model(**inputs).logits
53
+
54
+ probs = torch.sigmoid(logits)[0].cpu().numpy()
55
+
56
+ labels = []
57
+ for i, p in enumerate(probs):
58
+ if p >= threshold:
59
+ labels.append(clf_model.config.id2label[i])
60
+
61
+ return labels
62
+
63
+
64
+ def semantic_search(query, k=10):
65
+ query_emb = encoder.encode(query, convert_to_numpy=True)
66
+ scores, samples = train_ds.get_nearest_examples(
67
+ "embedding",
68
+ query_emb,
69
+ k=k
70
+ )
71
+ return scores, samples
72
+
73
+
74
+ def hybrid_search(query, alpha=0.7, beta=0.3, max_results=5):
75
+ sem_scores, sem_results = semantic_search(query, k=15)
76
+ predicted_labels = set(predict_labels(query))
77
+
78
+ seen = set()
79
+ results = []
80
+
81
+ for i in range(len(sem_scores)):
82
+ issue_id = sem_results["issue_number"][i]
83
+ if issue_id in seen:
84
+ continue
85
+ seen.add(issue_id)
86
+
87
+ issue_labels = set(sem_results["labels"][i])
88
+ overlap = (
89
+ len(issue_labels & predicted_labels) / len(issue_labels)
90
+ if issue_labels else 0.0
91
+ )
92
+
93
+ final_score = alpha * float(sem_scores[i]) + beta * overlap
94
+
95
+ results.append({
96
+ "Issue": f"#{issue_id}",
97
+ "Final score": round(final_score, 3),
98
+ "Semantic": round(float(sem_scores[i]), 3),
99
+ "Label overlap": round(overlap, 2),
100
+ "Labels": ", ".join(issue_labels),
101
+ "URL": sem_results["html_url"][i],
102
+ })
103
+
104
+ if len(results) >= max_results:
105
+ break
106
+
107
+ return predicted_labels, results
108
+
109
+
110
+ # -----------------------------
111
+ # GRADIO UI
112
+ # -----------------------------
113
+ def run_search(query):
114
+ if not query.strip():
115
+ return "Please enter a query.", []
116
+
117
+ labels, results = hybrid_search(query)
118
+
119
+ label_text = ", ".join(labels) if labels else "No label confidently predicted"
120
+ return label_text, results
121
+
122
+
123
+ with gr.Blocks(title="GitHub Issue Hybrid Search") as demo:
124
+ gr.Markdown(
125
+ """
126
+ # 🐙 GitHub Issue Hybrid Search & Auto-Label Assistant
127
+
128
+ **Semantic Search (MPNet) + Multilabel Classification (DistilBERT)**
129
+ Precision-first hybrid ranking on real scikit-learn issues.
130
+ """
131
+ )
132
+
133
+ query = gr.Textbox(
134
+ label="Describe the issue",
135
+ placeholder="e.g. RandomForestClassifier crashes when sample_weight is None"
136
+ )
137
+
138
+ btn = gr.Button("Search")
139
+
140
+ predicted = gr.Textbox(label="Predicted labels")
141
+ table = gr.Dataframe(
142
+ headers=["Issue", "Final score", "Semantic", "Label overlap", "Labels", "URL"],
143
+ wrap=True
144
+ )
145
+
146
+ btn.click(
147
+ fn=run_search,
148
+ inputs=query,
149
+ outputs=[predicted, table]
150
+ )
151
+
152
+ demo.launch()