import glob
import json
import os
import re
from typing import Dict, List, Tuple

import jsonlines
import numpy as np
import pandas as pd
from sklearn.metrics import pairwise_distances
from torch.utils.data import DataLoader

from bytecover.models.data_loader import bytecover_dataloader
from bytecover.models.data_model import Postfix


def dataloader_factory(config: Dict, data_split: str) -> List[DataLoader]:
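    """Build the DataLoaders for one data split.

    For the "TRAIN" split, one loader is created per value in
    config["train"]["max_seq_len"]; other splits produce a single loader,
    wrapped in a list so callers get a uniform return type.
    """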
    seq_len_key = "max_seq_len"

    if data_split == "TRAIN":
        # Training uses one loader per configured sequence length.
        t_loaders = []
        for seq_len in config["train"][seq_len_key]:
            t_loaders.append(
                bytecover_dataloader(
                    data_path=config["data_path"],
                    file_ext=config["file_extension"],
                    dataset_path=config["dataset_path"],
                    data_split=data_split,
                    debug=config["debug"],
                    max_len=seq_len,
                    **config["train"],
                )
            )
        return t_loaders

    # Validation/test splits use a single sequence length.
    seq_len = config[data_split.lower()][seq_len_key]
    return [
        bytecover_dataloader(
            data_path=config["data_path"],
            file_ext=config["file_extension"],
            dataset_path=config["dataset_path"],
            data_split=data_split,
            debug=config["debug"],
            max_len=seq_len,
            **config[data_split.lower()],
        )
    ]


def validation_triplet_sampling(anchor_id: str, val_ids: List[str], df: pd.DataFrame) -> Dict[str, int]:
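    """Sample a positive and a negative example for an anchor track.

    The positive is another version from the anchor's clique, the negative
    a track from any other clique; both are returned as integer positions
    into val_ids.
    """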
    # Positive: another version from the anchor's clique, mapped to its
    # position in val_ids.
    np.random.shuffle(df.loc[anchor_id, "versions"])
    pos_list = np.setdiff1d(df.loc[anchor_id, "versions"], anchor_id)
    pos_track = np.random.choice(pos_list, 1)[0]
    pos_id = val_ids.index(pos_track)

    # Negative: a random track from any other clique, also as a position
    # in val_ids.
    neg_track = df.loc[~df.index.isin([anchor_id] + list(pos_list))].sample(1).index[0]
    neg_id = val_ids.index(neg_track)

    return dict(pos_id=pos_id, neg_id=neg_id)


def calculate_ranking_metrics(embeddings: np.ndarray, cliques: List[int]) -> Tuple[np.ndarray, np.ndarray]:
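    """Score a set of embeddings by nearest-neighbour retrieval.

    Each embedding queries all the others. Returns, per query, the 0-based
    rank of the first correct match and the average precision, from which
    MR1 and mAP can be derived.
    """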
    # Pairwise distance matrix; sorting each row ascending gives the
    # retrieval ranking for that query (column 0 is assumed to be the
    # query itself, since self-distance is zero).
    distances = pairwise_distances(embeddings)
    s_distances = np.argsort(distances, axis=1)
    cliques = np.array(cliques)
    query_cliques = cliques[s_distances[:, 0]]
    search_cliques = cliques[s_distances[:, 1:]]

    # Relevance mask: True where a retrieved item belongs to the query's clique.
    query_cliques = np.tile(query_cliques, (search_cliques.shape[-1], 1)).T
    mask = np.equal(search_cliques, query_cliques)

    # 0-based rank of the first correct match for each query.
    ranks = mask.argmax(axis=1)

    # Average precision per query: precision at each relevant position,
    # summed and normalized by the number of relevant items.
    cumsum = np.cumsum(mask, axis=1)
    mask2 = mask * cumsum
    mask2 = mask2 / np.arange(1, mask2.shape[-1] + 1)
    average_precisions = np.sum(mask2, axis=1) / np.sum(mask, axis=1)

    return ranks, average_precisions


def dir_checker(output_dir: str) -> str:
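    """Return the next unused run-<N> subdirectory inside output_dir."""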
    # Strip a trailing run-<N> component so the base directory is reused.
    output_dir = re.sub(r"run-[0-9]+/*", "", output_dir)
    runs = glob.glob(os.path.join(output_dir, "run-*"))
    if runs:
        run = max(int(r.split("-")[-1]) for r in runs) + 1
    else:
        run = 0
    return os.path.join(output_dir, f"run-{run}")


def save_predictions(outputs: Dict[str, np.ndarray], output_dir: str) -> None:
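    """Persist model outputs to output_dir.

    Keys containing "_ids" are written as JSONL records (triplets or
    clique/anchor pairs); every other entry is saved as a .npy array.
    """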
    os.makedirs(output_dir, exist_ok=True)
    for key in outputs:
        if "_ids" in key:
            # Identifier arrays are written as JSONL, one record per row.
            with jsonlines.open(os.path.join(output_dir, f"{key}.jsonl"), "w") as f:
                if len(outputs[key][0]) == 4:
                    # Triplet rows: clique plus anchor/positive/negative ids.
                    for clique, anchor, pos, neg in outputs[key]:
                        f.write({"clique_id": clique, "anchor_id": anchor, "positive_id": pos, "negative_id": neg})
                else:
                    # Pair rows: clique plus anchor id only.
                    for clique, anchor in outputs[key]:
                        f.write({"clique_id": clique, "anchor_id": anchor})
        else:
            # All other outputs (e.g. embedding matrices) go to .npy files.
            np.save(os.path.join(output_dir, f"{key}.npy"), outputs[key])


def save_logs(outputs: dict, output_dir: str, name: str = "log") -> None:
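    """Append one record to <output_dir>/<name>.jsonl."""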
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, f"{name}.jsonl")
    with jsonlines.open(log_file, "a") as f:
        f.write(outputs)


def save_best_log(outputs: Postfix, output_dir: str) -> None:
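    """Overwrite <output_dir>/best-log.json with the given metrics record."""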
    os.makedirs(output_dir, exist_ok=True)
    log_file = os.path.join(output_dir, "best-log.json")
    with open(log_file, "w") as f:
        json.dump(outputs, f, indent=2)
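

if __name__ == "__main__":
    # Minimal smoke test for calculate_ranking_metrics; illustrative only.
    # The embedding size and clique labels below are made-up values, not
    # project data.
    rng = np.random.default_rng(0)
    embeddings = rng.normal(size=(6, 8))
    clique_labels = [0, 0, 1, 1, 2, 2]  # two versions per clique
    ranks, aps = calculate_ranking_metrics(embeddings, clique_labels)
    # ranks are 0-based, so +1 gives the conventional 1-based MR1.
    print("MR1:", ranks.mean() + 1, "mAP:", aps.mean())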