Commit
·
e8f04b1
1
Parent(s):
5eb1b80
Update pipeline progress bar
Browse files- RepoPipeline.py +16 -4
RepoPipeline.py
CHANGED
|
@@ -280,7 +280,7 @@ class RepoPipeline(Pipeline):
|
|
| 280 |
if not text_sets \
|
| 281 |
else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
|
| 282 |
|
| 283 |
-
def _forward(self, extracted_infos: List, max_length=512) -> List:
|
| 284 |
"""
|
| 285 |
The method for "forward" period.
|
| 286 |
:param extracted_infos: the information of repositories.
|
|
@@ -289,8 +289,9 @@ class RepoPipeline(Pipeline):
|
|
| 289 |
"""
|
| 290 |
model_outputs = []
|
| 291 |
# The number of repositories.
|
| 292 |
-
|
| 293 |
-
|
|
|
|
| 294 |
# For each repository
|
| 295 |
for repo_info in extracted_infos:
|
| 296 |
repo_name = repo_info["name"]
|
|
@@ -307,12 +308,18 @@ class RepoPipeline(Pipeline):
|
|
| 307 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
| 308 |
info["code_embeddings"] = code_embeddings.cpu().numpy()
|
| 309 |
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
# Doc embeddings
|
| 312 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
| 313 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
| 314 |
info["doc_embeddings"] = doc_embeddings.cpu().numpy()
|
| 315 |
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
# Requirement embeddings
|
| 318 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
|
@@ -320,12 +327,18 @@ class RepoPipeline(Pipeline):
|
|
| 320 |
info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
|
| 321 |
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
|
| 322 |
keepdim=True).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
# Readme embeddings
|
| 325 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
| 326 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
| 327 |
info["readme_embeddings"] = readme_embeddings.cpu().numpy()
|
| 328 |
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
| 329 |
|
| 330 |
# Repo-level mean embedding
|
| 331 |
info["mean_repo_embedding"] = np.concatenate([
|
|
@@ -345,7 +358,6 @@ class RepoPipeline(Pipeline):
|
|
| 345 |
info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
|
| 346 |
info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
|
| 347 |
|
| 348 |
-
progress_bar.update(1)
|
| 349 |
model_outputs.append(info)
|
| 350 |
|
| 351 |
return model_outputs
|
|
|
|
| 280 |
if not text_sets \
|
| 281 |
else torch.cat([self.encode(text, max_length) for text in text_sets], dim=0)
|
| 282 |
|
| 283 |
+
def _forward(self, extracted_infos: List, max_length=512, st_progress=None) -> List:
|
| 284 |
"""
|
| 285 |
The method for "forward" period.
|
| 286 |
:param extracted_infos: the information of repositories.
|
|
|
|
| 289 |
"""
|
| 290 |
model_outputs = []
|
| 291 |
# The number of repositories.
|
| 292 |
+
num_texts = sum(
|
| 293 |
+
len(x["codes"]) + len(x["docs"]) + len(x["requirements"]) + len(x["readmes"]) for x in extracted_infos)
|
| 294 |
+
with tqdm(total=num_texts) as progress_bar:
|
| 295 |
# For each repository
|
| 296 |
for repo_info in extracted_infos:
|
| 297 |
repo_name = repo_info["name"]
|
|
|
|
| 308 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
| 309 |
info["code_embeddings"] = code_embeddings.cpu().numpy()
|
| 310 |
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 311 |
+
progress_bar.update(len(repo_info["codes"]))
|
| 312 |
+
if st_progress:
|
| 313 |
+
st_progress.progress(progress_bar.n / progress_bar.total)
|
| 314 |
|
| 315 |
# Doc embeddings
|
| 316 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
| 317 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
| 318 |
info["doc_embeddings"] = doc_embeddings.cpu().numpy()
|
| 319 |
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 320 |
+
progress_bar.update(len(repo_info["docs"]))
|
| 321 |
+
if st_progress:
|
| 322 |
+
st_progress.progress(progress_bar.n / progress_bar.total)
|
| 323 |
|
| 324 |
# Requirement embeddings
|
| 325 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
|
|
|
| 327 |
info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
|
| 328 |
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
|
| 329 |
keepdim=True).cpu().numpy()
|
| 330 |
+
progress_bar.update(len(repo_info["requirements"]))
|
| 331 |
+
if st_progress:
|
| 332 |
+
st_progress.progress(progress_bar.n / progress_bar.total)
|
| 333 |
|
| 334 |
# Readme embeddings
|
| 335 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
| 336 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
| 337 |
info["readme_embeddings"] = readme_embeddings.cpu().numpy()
|
| 338 |
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 339 |
+
progress_bar.update(len(repo_info["readmes"]))
|
| 340 |
+
if st_progress:
|
| 341 |
+
st_progress.progress(progress_bar.n / progress_bar.total)
|
| 342 |
|
| 343 |
# Repo-level mean embedding
|
| 344 |
info["mean_repo_embedding"] = np.concatenate([
|
|
|
|
| 358 |
info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
|
| 359 |
info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
|
| 360 |
|
|
|
|
| 361 |
model_outputs.append(info)
|
| 362 |
|
| 363 |
return model_outputs
|