import argparse
import contextlib
import sys

from Levenshtein import distance
from tqdm import tqdm


# Heuristic knobs of the nearest-neighbour search.
# Token-count difference above which the real edit distance is not computed.
MAX_LENGTH_GAP = 10
# Starting value of the best distance; candidates at or above it are never
# selected, in which case the reported index stays -1.
INITIAL_BEST = 100


def tokenize(line):
    """Return the space-separated tokens of *line* after trimming whitespace."""
    return line.strip().split(" ")


def index_train_lines(lines):
    """Tokenise the training lines and index them for exact lookup.

    Args:
        lines: iterable of raw text lines.

    Returns:
        A ``(tokens, index)`` pair. ``tokens[i]`` is the token list of line
        ``i``; ``index`` maps each stripped line to the position of its last
        occurrence.
    """
    tokens = []
    index = {}
    for idx, line in enumerate(lines):
        # Key on the stripped text itself rather than on hash(line): this
        # rules out false matches from hash collisions and makes a final line
        # without a trailing newline comparable with the others.
        index[line.strip()] = idx
        tokens.append(tokenize(line))
    return tokens, index


def find_nearest(query, candidates, metric, skip=None):
    """Find the candidate token list closest to *query*.

    Args:
        query: token list to look up.
        candidates: sequence of token lists to compare against.
        metric: callable ``metric(candidate, query)`` returning a distance.
        skip: optional candidate index to ignore (used so that a line is not
            matched against itself when a file is compared with itself).

    Returns:
        ``(index, dist, avg)`` of the first candidate with the smallest
        distance, where ``avg`` is the distance divided by the combined token
        count. If no candidate is closer than ``INITIAL_BEST`` the result is
        ``(-1, INITIAL_BEST, INITIAL_BEST)``.
    """
    best_avg, best_dist, best_idx = INITIAL_BEST, INITIAL_BEST, -1
    for idx, cand in enumerate(candidates):
        if idx == skip:
            continue
        gap = abs(len(cand) - len(query))
        if gap > MAX_LENGTH_GAP:
            # Very different lengths: use a cheap estimate instead of
            # calling the metric.
            dist = gap * 2
        else:
            dist = metric(cand, query)
        if dist < best_dist:
            total = len(cand) + len(query)
            best_avg = dist / total if total else 0.0
            best_dist, best_idx = dist, idx
    return best_idx, best_dist, best_avg


def main(argv=None):
    """Report test lines that duplicate, or nearly duplicate, train lines.

    For every line of ``--test`` that also occurs in ``--train`` the record
    ``"<test idx> <train idx> 0 0.0"`` is written to ``--result`` (if given)
    and counted. With ``--distance`` (honoured only together with
    ``--result``) the remaining lines get a record naming their nearest train
    line: ``"<test idx> <train idx> <distance> <normalised distance>"``.
    The number of exact duplicates is printed at the end.

    Args:
        argv: argument list to parse; defaults to the process arguments.
    """
    parser = argparse.ArgumentParser(
        description="Find duplicates in the dataset ASM"
    )
    parser.add_argument("--train", required=True)
    parser.add_argument("--test", required=True)
    parser.add_argument("--result", required=False)
    parser.add_argument("--distance", action="store_true", default=False)
    args = parser.parse_args(argv)

    with open(args.train, "r") as tf:
        train, train_index = index_train_lines(
            tqdm(tf, desc="Read train", leave=False)
        )
    with open(args.test, "r") as tf:
        test = list(tqdm(tf, desc="Read test", leave=False))

    # Comparing a file with itself: a line must not count as its own duplicate.
    selfcheck = args.train == args.test
    # Without a result file there is nowhere to report distances.
    searchdist = bool(args.result) and args.distance

    exact = 0
    with contextlib.ExitStack() as stack:
        # The result file is closed on exit, even if processing fails midway.
        rf = stack.enter_context(open(args.result, "w")) if args.result else None
        for i, testline in tqdm(enumerate(test), desc="Test", total=len(test)):
            j = train_index.get(testline.strip())
            if j is not None and (not selfcheck or j != i):
                exact += 1
                if rf is not None:
                    rf.write(f"{i} {j} 0 0.0\n")
                continue

            if searchdist:
                minj, mindist, minavgdist = find_nearest(
                    tokenize(testline),
                    train,
                    distance,
                    skip=i if selfcheck else None,
                )
                rf.write(f"{i} {minj} {mindist} {minavgdist}\n")

    print("Exact duplicates:", exact)


if __name__ == "__main__":
    main()