Add streamlit widget support to the pipeline
Browse files
- pipeline.py +31 -8
pipeline.py
CHANGED
|
@@ -6,7 +6,7 @@ from io import BytesIO
|
|
| 6 |
import numpy as np
|
| 7 |
import requests
|
| 8 |
import torch
|
| 9 |
-
from tqdm import tqdm
|
| 10 |
from transformers import Pipeline
|
| 11 |
|
| 12 |
|
|
@@ -96,26 +96,38 @@ def download_and_extract(repos, headers=None):
|
|
| 96 |
|
| 97 |
|
| 98 |
class RepoEmbeddingPipeline(Pipeline):
|
| 99 |
-
def __init__(self, github_token=None, *args, **kwargs):
|
| 100 |
super().__init__(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
| 102 |
if not github_token:
|
| 103 |
-
|
| 104 |
-
"[
|
| 105 |
-
"For more info, see:"
|
| 106 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
| 107 |
)
|
|
|
|
|
|
|
|
|
|
| 108 |
else:
|
| 109 |
self.set_github_token(github_token)
|
| 110 |
|
| 111 |
def set_github_token(self, github_token):
|
| 112 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def _sanitize_parameters(self, **kwargs):
|
| 116 |
_forward_kwargs = {}
|
| 117 |
if "max_length" in kwargs:
|
| 118 |
_forward_kwargs["max_length"] = kwargs["max_length"]
|
|
|
|
|
|
|
| 119 |
|
| 120 |
return {}, _forward_kwargs, {}
|
| 121 |
|
|
@@ -123,6 +135,8 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 123 |
if isinstance(inputs, str):
|
| 124 |
inputs = (inputs,)
|
| 125 |
|
|
|
|
|
|
|
| 126 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
| 127 |
|
| 128 |
return extracted_infos
|
|
@@ -153,7 +167,7 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 153 |
|
| 154 |
return sentence_embeddings
|
| 155 |
|
| 156 |
-
def _forward(self, extracted_infos, max_length=512):
|
| 157 |
repo_dataset = {}
|
| 158 |
num_texts = sum(
|
| 159 |
len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
|
|
@@ -163,14 +177,20 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 163 |
pbar.set_description(f"Processing {repo_name}")
|
| 164 |
entry = {"topics": repo_info.get("topics")}
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
code_embeddings = []
|
| 169 |
for func in repo_info["funcs"]:
|
| 170 |
code_embeddings.append(
|
| 171 |
[func, self.encode(func, max_length).squeeze().tolist()]
|
| 172 |
)
|
|
|
|
| 173 |
pbar.update(1)
|
|
|
|
|
|
|
| 174 |
|
| 175 |
entry["code_embeddings"] = code_embeddings
|
| 176 |
entry["mean_code_embedding"] = (
|
|
@@ -184,7 +204,10 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 184 |
doc_embeddings.append(
|
| 185 |
[doc, self.encode(doc, max_length).squeeze().tolist()]
|
| 186 |
)
|
|
|
|
| 187 |
pbar.update(1)
|
|
|
|
|
|
|
| 188 |
|
| 189 |
entry["doc_embeddings"] = doc_embeddings
|
| 190 |
entry["mean_doc_embedding"] = (
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import requests
|
| 8 |
import torch
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
from transformers import Pipeline
|
| 11 |
|
| 12 |
|
|
|
|
| 96 |
|
| 97 |
|
| 98 |
class RepoEmbeddingPipeline(Pipeline):
|
| 99 |
+
def __init__(self, github_token=None, st_messager=None, *args, **kwargs):
|
| 100 |
super().__init__(*args, **kwargs)
|
| 101 |
+
|
| 102 |
+
# Streamlit single element container created by st.empty()
|
| 103 |
+
self.st_messager = st_messager
|
| 104 |
+
|
| 105 |
self.API_HEADERS = {"Accept": "application/vnd.github+json"}
|
| 106 |
if not github_token:
|
| 107 |
+
message = (
|
| 108 |
+
"[*] Consider setting GitHub token to avoid hitting rate limits. \n"
|
| 109 |
+
"For more info, see: "
|
| 110 |
"https://docs.github.com/authentication/keeping-your-account-and-data-secure/creating-a-personal-access-token"
|
| 111 |
)
|
| 112 |
+
print(message)
|
| 113 |
+
if self.st_messager:
|
| 114 |
+
self.st_messager.info(message)
|
| 115 |
else:
|
| 116 |
self.set_github_token(github_token)
|
| 117 |
|
| 118 |
def set_github_token(self, github_token):
|
| 119 |
self.API_HEADERS["Authorization"] = f"Bearer {github_token}"
|
| 120 |
+
message = "[+] GitHub token set"
|
| 121 |
+
print(message)
|
| 122 |
+
if self.st_messager:
|
| 123 |
+
self.st_messager.success(message)
|
| 124 |
|
| 125 |
def _sanitize_parameters(self, **kwargs):
|
| 126 |
_forward_kwargs = {}
|
| 127 |
if "max_length" in kwargs:
|
| 128 |
_forward_kwargs["max_length"] = kwargs["max_length"]
|
| 129 |
+
if "st_progress" in kwargs:
|
| 130 |
+
_forward_kwargs["st_progress"] = kwargs["st_progress"]
|
| 131 |
|
| 132 |
return {}, _forward_kwargs, {}
|
| 133 |
|
|
|
|
| 135 |
if isinstance(inputs, str):
|
| 136 |
inputs = (inputs,)
|
| 137 |
|
| 138 |
+
if self.st_messager:
|
| 139 |
+
self.st_messager.info("[*] Downloading and extracting repos...")
|
| 140 |
extracted_infos = download_and_extract(inputs, headers=self.API_HEADERS)
|
| 141 |
|
| 142 |
return extracted_infos
|
|
|
|
| 167 |
|
| 168 |
return sentence_embeddings
|
| 169 |
|
| 170 |
+
def _forward(self, extracted_infos, max_length=512, st_progress=None):
|
| 171 |
repo_dataset = {}
|
| 172 |
num_texts = sum(
|
| 173 |
len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
|
|
|
|
| 177 |
pbar.set_description(f"Processing {repo_name}")
|
| 178 |
entry = {"topics": repo_info.get("topics")}
|
| 179 |
|
| 180 |
+
message = f"[*] Generating embeddings for {repo_name}"
|
| 181 |
+
tqdm.write(message)
|
| 182 |
+
if self.st_messager:
|
| 183 |
+
self.st_messager.info(message)
|
| 184 |
|
| 185 |
code_embeddings = []
|
| 186 |
for func in repo_info["funcs"]:
|
| 187 |
code_embeddings.append(
|
| 188 |
[func, self.encode(func, max_length).squeeze().tolist()]
|
| 189 |
)
|
| 190 |
+
|
| 191 |
pbar.update(1)
|
| 192 |
+
if st_progress:
|
| 193 |
+
st_progress.progress(pbar.n / pbar.total)
|
| 194 |
|
| 195 |
entry["code_embeddings"] = code_embeddings
|
| 196 |
entry["mean_code_embedding"] = (
|
|
|
|
| 204 |
doc_embeddings.append(
|
| 205 |
[doc, self.encode(doc, max_length).squeeze().tolist()]
|
| 206 |
)
|
| 207 |
+
|
| 208 |
pbar.update(1)
|
| 209 |
+
if st_progress:
|
| 210 |
+
st_progress.progress(pbar.n / pbar.total)
|
| 211 |
|
| 212 |
entry["doc_embeddings"] = doc_embeddings
|
| 213 |
entry["mean_doc_embedding"] = (
|