Show progress bar when running the model
Browse files
- pipeline.py +27 -15
pipeline.py
CHANGED
|
@@ -6,6 +6,7 @@ from io import BytesIO
|
|
| 6 |
import numpy as np
|
| 7 |
import requests
|
| 8 |
import torch
|
|
|
|
| 9 |
from transformers import Pipeline
|
| 10 |
|
| 11 |
|
|
@@ -154,26 +155,37 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 154 |
|
| 155 |
def _forward(self, extracted_infos, max_length=512):
|
| 156 |
repo_dataset = {}
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
entry["code_embeddings"] = code_embeddings
|
| 167 |
entry["mean_code_embedding"] = (
|
| 168 |
np.mean([x[1] for x in code_embeddings], axis=0).tolist()
|
| 169 |
if code_embeddings
|
| 170 |
else None
|
| 171 |
)
|
| 172 |
-
|
| 173 |
-
doc_embeddings = [
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
entry["doc_embeddings"] = doc_embeddings
|
| 178 |
entry["mean_doc_embedding"] = (
|
| 179 |
np.mean([x[1] for x in doc_embeddings], axis=0).tolist()
|
|
@@ -181,7 +193,7 @@ class RepoEmbeddingPipeline(Pipeline):
|
|
| 181 |
else None
|
| 182 |
)
|
| 183 |
|
| 184 |
-
|
| 185 |
|
| 186 |
return repo_dataset
|
| 187 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import requests
|
| 8 |
import torch
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
from transformers import Pipeline
|
| 11 |
|
| 12 |
|
|
|
|
| 155 |
|
| 156 |
def _forward(self, extracted_infos, max_length=512):
    """Embed every function and document of each repository.

    A single progress bar tracks the total number of texts (functions
    plus documents) across all repositories.

    Args:
        extracted_infos: Mapping of repository name to a dict. Each dict
            is read for ``"funcs"`` and ``"docs"`` (sized iterables of
            texts to encode) and, optionally, ``"topics"``.
        max_length: Passed through to ``self.encode`` for every text.

    Returns:
        A dict mapping each repository name to an entry with the keys
        ``"topics"``, ``"code_embeddings"``, ``"mean_code_embedding"``,
        ``"doc_embeddings"`` and ``"mean_doc_embedding"``. Each
        ``*_embeddings`` value is a list of ``[text, embedding]`` pairs;
        each mean is ``None`` when the repository has no texts of that
        kind.
    """
    repo_dataset = {}
    num_texts = sum(
        len(x["funcs"]) + len(x["docs"]) for x in extracted_infos.values()
    )

    with tqdm(total=num_texts) as pbar:

        def embed_all(texts):
            # Encode each text and advance the shared bar by one per text.
            # NOTE(review): assumes self.encode returns a tensor/array-like
            # supporting .squeeze().tolist() -- confirm against encode().
            pairs = []
            for text in texts:
                pairs.append(
                    [text, self.encode(text, max_length).squeeze().tolist()]
                )
                pbar.update(1)
            return pairs

        def mean_embedding(pairs):
            # np.mean over an empty list would yield NaN plus a warning,
            # so an empty repository gets an explicit None instead.
            if not pairs:
                return None
            return np.mean([vector for _, vector in pairs], axis=0).tolist()

        for repo_name, repo_info in extracted_infos.items():
            pbar.set_description(f"Processing {repo_name}")
            entry = {"topics": repo_info.get("topics")}

            # tqdm.write instead of print: a plain print while the bar is
            # active interleaves with and corrupts the bar rendering.
            tqdm.write(f"[+] Generating embeddings for {repo_name}")

            code_embeddings = embed_all(repo_info["funcs"])
            entry["code_embeddings"] = code_embeddings
            entry["mean_code_embedding"] = mean_embedding(code_embeddings)

            doc_embeddings = embed_all(repo_info["docs"])
            entry["doc_embeddings"] = doc_embeddings
            entry["mean_doc_embedding"] = mean_embedding(doc_embeddings)

            repo_dataset[repo_name] = entry

    return repo_dataset
|
| 199 |
|