import glob
import json
import os
import re
from typing import Dict, List, Tuple

import jsonlines
import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances
from torch.utils.data import DataLoader

from bytecover.models.data_loader import bytecover_dataloader
from bytecover.models.data_model import Postfix


def dataloader_factory(config: Dict, data_split: str) -> List[DataLoader]:
    """Build the dataloaders for the requested split from the config dict."""
    seq_len_key = "max_seq_len"
    if data_split == "TRAIN":
        # Training may use several sequence lengths, so build one loader per length.
        t_loaders = []
        for max_len in config["train"][seq_len_key]:
            t_loaders.append(
                bytecover_dataloader(
                    data_path=config["data_path"],
                    file_ext=config["file_extension"],
                    dataset_path=config["dataset_path"],
                    data_split=data_split,
                    debug=config["debug"],
                    max_len=max_len,
                    **config["train"],
                )
            )
        return t_loaders
    # Validation/test splits use a single sequence length.
    max_len = config[data_split.lower()][seq_len_key]
    return [
        bytecover_dataloader(
            data_path=config["data_path"],
            file_ext=config["file_extension"],
            dataset_path=config["dataset_path"],
            data_split=data_split,
            debug=config["debug"],
            max_len=max_len,
            **config[data_split.lower()],
        )
    ]
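
# A minimal sketch of the config this factory expects (illustrative values
# only; the real keys and paths come from the project's config files, which
# are not part of this module):
#
#   config = {
#       "data_path": "data/audio",            # hypothetical path
#       "file_extension": ".wav",             # hypothetical extension
#       "dataset_path": "data/meta.csv",      # hypothetical path
#       "debug": False,
#       "train": {"max_seq_len": [50, 100]},  # one loader per length
#       "val": {"max_seq_len": 100},
#   }
#   train_loaders = dataloader_factory(config, data_split="TRAIN")
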
def validation_triplet_sampling(anchor_id: str, val_ids: List[str], df: pd.DataFrame) -> Dict[str, int]:
    """Sample positive and negative indices for an anchor within the validation set."""
    np.random.shuffle(df.loc[anchor_id, "versions"])
    # Positive: another version from the anchor's clique.
    pos_list = np.setdiff1d(df.loc[anchor_id, "versions"], anchor_id)
    pos_id = np.random.choice(pos_list, 1)[0]
    pos_id = val_ids.index(pos_id)
    # Negative: a random track from outside the anchor's clique.
    neg_id = df.loc[~df.index.isin([anchor_id] + list(pos_list))].sample(1).index[0]
    neg_id = val_ids.index(neg_id)
    return dict(pos_id=pos_id, neg_id=neg_id)
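
# Sketch of the DataFrame layout the sampler assumes (hypothetical ids): the
# index holds track ids and "versions" lists every member of the track's
# clique, so the positive comes from inside the clique and the negative from
# outside it:
#
#   >>> df = pd.DataFrame({"versions": [["a", "b"], ["a", "b"], ["c"]]},
#   ...                   index=["a", "b", "c"])
#   >>> validation_triplet_sampling("a", ["a", "b", "c"], df)
#   {'pos_id': 1, 'neg_id': 2}
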
def calculate_ranking_metrics(embeddings: np.ndarray, cliques: List[int]) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-query first-hit ranks and average precisions from embeddings."""
    distances = pairwise_distances(embeddings)
    s_distances = np.argsort(distances, axis=1)
    cliques = np.array(cliques)
    # Column 0 of the sorted distances is each query itself (distance 0),
    # so the ranked search results start at column 1.
    query_cliques = cliques[s_distances[:, 0]]
    search_cliques = cliques[s_distances[:, 1:]]
    query_cliques = np.tile(query_cliques, (search_cliques.shape[-1], 1)).T
    mask = np.equal(search_cliques, query_cliques)
    # Position of the first correct match (0-based) for each query.
    ranks = mask.argmax(axis=1)
    # Average precision: precision at each relevant position, averaged over
    # the number of relevant items.
    cumsum = np.cumsum(mask, axis=1)
    mask2 = mask * cumsum
    mask2 = mask2 / np.arange(1, mask2.shape[-1] + 1)
    average_precisions = np.sum(mask2, axis=1) / np.sum(mask, axis=1)
    return ranks, average_precisions
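
# A small worked example with toy embeddings (not from a real model): items 0
# and 1 share clique 7, item 2 is a distractor from clique 9.
#
#   >>> emb = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
#   >>> ranks, aps = calculate_ranking_metrics(emb, cliques=[7, 7, 9])
#   >>> ranks[:2]  # the true cover is the nearest neighbor for queries 0 and 1
#   array([0, 0])
#   >>> aps[:2]    # single relevant item retrieved first => AP = 1.0
#   array([1., 1.])
#
# Note that a query whose clique has no other member (item 2 here) divides by
# zero relevant items, so its average precision comes out as NaN.
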
def dir_checker(output_dir: str) -> str:
    """Return the next free run directory (run-0, run-1, ...) under output_dir."""
    # Strip any existing run-<n> suffix so repeated calls do not nest runs.
    output_dir = re.sub(r"run-[0-9]+/*", "", output_dir)
    runs = glob.glob(os.path.join(output_dir, "run-*"))
    if runs:
        max_run = max(int(run_path.split("-")[-1]) for run_path in runs)
        run = max_run + 1
    else:
        run = 0
    return os.path.join(output_dir, f"run-{run}")
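
# Example behavior (hypothetical paths): with "outputs/run-0" and
# "outputs/run-1" already on disk, the next free run directory is returned,
# and an existing run suffix in the argument is stripped first:
#
#   >>> dir_checker("outputs")        # depends on the filesystem state
#   'outputs/run-2'
#   >>> dir_checker("outputs/run-1")
#   'outputs/run-2'
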
def save_predictions(outputs: Dict[str, np.ndarray], output_dir: str) -> None:
    """Write id tuples as JSONL files and everything else as .npy arrays."""
    os.makedirs(output_dir, exist_ok=True)
    for key in outputs:
        if "_ids" in key:
            with jsonlines.open(os.path.join(output_dir, f"{key}.jsonl"), "w") as f:
                if len(outputs[key][0]) == 4:
                    # Triplet predictions: clique with anchor/positive/negative ids.
                    for clique, anchor, pos, neg in outputs[key]:
                        f.write({"clique_id": clique, "anchor_id": anchor, "positive_id": pos, "negative_id": neg})
                else:
                    # Pair predictions: clique and anchor id only.
                    for clique, anchor in outputs[key]:
                        f.write({"clique_id": clique, "anchor_id": anchor})
        else:
            np.save(os.path.join(output_dir, f"{key}.npy"), outputs[key])
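
# Sketch of a predictions dict this function accepts (hypothetical keys and
# values): keys containing "_ids" become JSONL files, everything else is
# saved as a NumPy array.
#
#   outputs = {
#       "triplet_ids": np.array([["12", "a", "b", "c"]]),  # -> triplet_ids.jsonl
#       "embeddings": np.zeros((3, 128)),                  # -> embeddings.npy
#   }
#   save_predictions(outputs, "outputs/run-0")
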
def save_logs(outputs: dict, output_dir: str, name: str = "log") -> None:
    """Append a single record to a JSONL log file."""
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, f"{name}.jsonl")
    with jsonlines.open(log_file, "a") as f:
        f.write(outputs)


def save_best_log(outputs: Postfix, output_dir: str) -> None:
    """Overwrite best-log.json with the metrics of the best run so far."""
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, "best-log.json")
    with open(log_file, "w") as f:
        json.dump(outputs, f, indent=2)