Update RepoPipeline.py
Browse files
- RepoPipeline.py +12 -8
RepoPipeline.py
CHANGED
|
@@ -126,8 +126,8 @@ def extract_information(repos, headers=None):
|
|
| 126 |
)
|
| 127 |
except SyntaxError as e:
|
| 128 |
tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
|
|
|
|
| 129 |
elif (member.name.endswith("README.md") or member.name.endswith("README.rst")) and member.isfile():
|
| 130 |
-
# 3. Extracting readme.
|
| 131 |
try:
|
| 132 |
file_content = tar.extractfile(member).read().decode("utf-8")
|
| 133 |
# extract readme
|
|
@@ -140,8 +140,8 @@ def extract_information(repos, headers=None):
|
|
| 140 |
)
|
| 141 |
except SyntaxError as e:
|
| 142 |
tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
|
|
|
|
| 143 |
elif member.name.endswith("requirements.txt") and member.isfile():
|
| 144 |
-
# 4. Extracting requirements.
|
| 145 |
try:
|
| 146 |
lines = tar.extractfile(member).readlines().decode("utf-8")
|
| 147 |
# extract readme
|
|
@@ -290,25 +290,26 @@ class RepoPipeline(Pipeline):
|
|
| 290 |
tqdm.write(f"[*] Generating code embeddings for {repo_name}")
|
| 291 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
| 292 |
info["code_embeddings"] = code_embeddings.cpu().numpy()
|
| 293 |
-
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0).cpu().numpy()
|
| 294 |
|
| 295 |
# Doc embeddings
|
| 296 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
| 297 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
| 298 |
info["doc_embeddings"] = doc_embeddings.cpu().numpy()
|
| 299 |
-
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0).cpu().numpy()
|
| 300 |
|
| 301 |
# Requirement embeddings
|
| 302 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
| 303 |
requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
|
| 304 |
info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
|
| 305 |
-
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0).cpu().numpy()
|
|
|
|
| 306 |
|
| 307 |
# Readme embeddings
|
| 308 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
| 309 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
| 310 |
info["readme_embeddings"] = readme_embeddings.cpu().numpy()
|
| 311 |
-
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0).cpu().numpy()
|
| 312 |
|
| 313 |
# Repo-level mean embedding
|
| 314 |
info["mean_repo_embedding"] = np.concatenate([
|
|
@@ -316,13 +317,16 @@ class RepoPipeline(Pipeline):
|
|
| 316 |
info["mean_doc_embedding"],
|
| 317 |
info["mean_requirement_embedding"],
|
| 318 |
info["mean_readme_embedding"]
|
| 319 |
-
], axis=0)
|
| 320 |
|
| 321 |
-
# TODO Remove test
|
| 322 |
info["code_embeddings_shape"] = info["code_embeddings"].shape
|
|
|
|
| 323 |
info["doc_embeddings_shape"] = info["doc_embeddings"].shape
|
|
|
|
| 324 |
info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
|
|
|
|
| 325 |
info["readme_embeddings_shape"] = info["readme_embeddings"].shape
|
|
|
|
| 326 |
info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
|
| 327 |
|
| 328 |
progress_bar.update(1)
|
|
|
|
| 126 |
)
|
| 127 |
except SyntaxError as e:
|
| 128 |
tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
|
| 129 |
+
# 3. Extracting readme.
|
| 130 |
elif (member.name.endswith("README.md") or member.name.endswith("README.rst")) and member.isfile():
|
|
|
|
| 131 |
try:
|
| 132 |
file_content = tar.extractfile(member).read().decode("utf-8")
|
| 133 |
# extract readme
|
|
|
|
| 140 |
)
|
| 141 |
except SyntaxError as e:
|
| 142 |
tqdm.write(f"[-] SyntaxError in {member.name}, skipping: \n{e}")
|
| 143 |
+
# 4. Extracting requirements.
|
| 144 |
elif member.name.endswith("requirements.txt") and member.isfile():
|
|
|
|
| 145 |
try:
|
| 146 |
lines = tar.extractfile(member).readlines().decode("utf-8")
|
| 147 |
# extract readme
|
|
|
|
| 290 |
tqdm.write(f"[*] Generating code embeddings for {repo_name}")
|
| 291 |
code_embeddings = self.generate_embeddings(repo_info["codes"], max_length)
|
| 292 |
info["code_embeddings"] = code_embeddings.cpu().numpy()
|
| 293 |
+
info["mean_code_embedding"] = torch.mean(code_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 294 |
|
| 295 |
# Doc embeddings
|
| 296 |
tqdm.write(f"[*] Generating doc embeddings for {repo_name}")
|
| 297 |
doc_embeddings = self.generate_embeddings(repo_info["docs"], max_length)
|
| 298 |
info["doc_embeddings"] = doc_embeddings.cpu().numpy()
|
| 299 |
+
info["mean_doc_embedding"] = torch.mean(doc_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 300 |
|
| 301 |
# Requirement embeddings
|
| 302 |
tqdm.write(f"[*] Generating requirement embeddings for {repo_name}")
|
| 303 |
requirement_embeddings = self.generate_embeddings(repo_info["requirements"], max_length)
|
| 304 |
info["requirement_embeddings"] = requirement_embeddings.cpu().numpy()
|
| 305 |
+
info["mean_requirement_embedding"] = torch.mean(requirement_embeddings, dim=0,
|
| 306 |
+
keepdim=True).cpu().numpy()
|
| 307 |
|
| 308 |
# Readme embeddings
|
| 309 |
tqdm.write(f"[*] Generating readme embeddings for {repo_name}")
|
| 310 |
readme_embeddings = self.generate_embeddings(repo_info["readmes"], max_length)
|
| 311 |
info["readme_embeddings"] = readme_embeddings.cpu().numpy()
|
| 312 |
+
info["mean_readme_embedding"] = torch.mean(readme_embeddings, dim=0, keepdim=True).cpu().numpy()
|
| 313 |
|
| 314 |
# Repo-level mean embedding
|
| 315 |
info["mean_repo_embedding"] = np.concatenate([
|
|
|
|
| 317 |
info["mean_doc_embedding"],
|
| 318 |
info["mean_requirement_embedding"],
|
| 319 |
info["mean_readme_embedding"]
|
| 320 |
+
], axis=0).reshape(1, -1)
|
| 321 |
|
|
|
|
| 322 |
info["code_embeddings_shape"] = info["code_embeddings"].shape
|
| 323 |
+
info["mean_code_embedding_shape"] = info["mean_code_embedding"].shape
|
| 324 |
info["doc_embeddings_shape"] = info["doc_embeddings"].shape
|
| 325 |
+
info["mean_doc_embedding_shape"] = info["mean_doc_embedding"].shape
|
| 326 |
info["requirement_embeddings_shape"] = info["requirement_embeddings"].shape
|
| 327 |
+
info["mean_requirement_embedding_shape"] = info["mean_requirement_embedding"].shape
|
| 328 |
info["readme_embeddings_shape"] = info["readme_embeddings"].shape
|
| 329 |
+
info["mean_readme_embedding_shape"] = info["mean_readme_embedding"].shape
|
| 330 |
info["mean_repo_embedding_shape"] = info["mean_repo_embedding"].shape
|
| 331 |
|
| 332 |
progress_bar.update(1)
|