| """ |
| Generate glycan-cold splits for the combined Bertint dataset. |
| Clusters glycans by WURCS string edit distance, assigns to train/val/test. |
| """ |
| import argparse |
| import json |
| import csv |
| import random |
| from typing import Dict, List, Set, Tuple |
| from collections import defaultdict |
|
|
def levenshtein_ratio(s1: str, s2: str) -> float:
    """Return the Levenshtein distance between two strings scaled to [0, 1].

    The raw distance (minimum number of single-character insertions,
    deletions and substitutions turning ``s1`` into ``s2``) is divided by
    the length of the longer string. Identical strings give 0.0; if exactly
    one string is empty the result is 1.0.
    """
    if s1 == s2:
        return 0.0
    if not s1 or not s2:
        return 1.0
    longest = max(len(s1), len(s2))
    # Two-row dynamic programme: `above` holds the distances from the
    # previous prefix of s1 to every prefix of s2; `row` is built for the
    # current prefix and then becomes `above`.
    above = list(range(len(s2) + 1))
    for i, ch1 in enumerate(s1, start=1):
        row = [i]
        for j, ch2 in enumerate(s2, start=1):
            substitution = above[j - 1] + (ch1 != ch2)
            row.append(min(row[j - 1] + 1, above[j] + 1, substitution))
        above = row
    return above[-1] / longest
|
|
def simple_cluster(glycans: List[str], threshold: float = 0.5, max_pw: int = 5000) -> List[List[str]]:
    """Cluster glycans by WURCS string similarity with single-linkage union-find.

    Two glycans are linked when ``levenshtein_ratio`` of their strings is
    strictly below ``threshold``. The returned clusters are the connected
    components of those links, so chains of similar glycans merge even if
    the two ends of the chain are dissimilar.

    Args:
        glycans: glycan strings (WURCS) to cluster.
        threshold: normalized edit-distance cutoff; a pair is linked when
            its ratio is ``< threshold``.
        max_pw: when ``len(glycans) <= max_pw`` every pair is compared.
            Above that, glycans are bucketed by the first three
            "/"-separated fields of the string (or its first 50 characters
            if it has fewer than three fields). Only pairs inside a bucket
            are compared, so similar glycans in different buckets are never
            linked in that mode.

    Returns:
        A list of clusters, each a list of glycan strings. Clusters are
        ordered by their earliest member in ``glycans``, and members keep
        their input order. An empty input yields ``[]``.
    """
    from itertools import combinations

    n = len(glycans)
    print(f" Clustering {n} glycans (threshold={threshold})...")
    if n == 0:
        # Nothing to cluster; also avoids max()/median of an empty size list.
        print(" 0 clusters")
        return []

    parent = list(range(n))
    lengths = [len(g) for g in glycans]

    def find(x: int) -> int:
        """Return the root of x's component, halving the path on the way."""
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def link_if_similar(i: int, j: int) -> None:
        """Merge the components of i and j if the two glycans are within threshold."""
        root_i, root_j = find(i), find(j)
        if root_i == root_j:
            # Already connected: the distance cannot change the components,
            # so skip the expensive DP.
            return
        longest = max(lengths[i], lengths[j])
        # Edit distance is at least the length difference, so this is a
        # lower bound on levenshtein_ratio. If the bound already reaches the
        # threshold the pair can never be linked, and the O(len^2) DP is
        # skipped. This pruning is exact.
        if longest and abs(lengths[i] - lengths[j]) / longest >= threshold:
            return
        if levenshtein_ratio(glycans[i], glycans[j]) < threshold:
            parent[root_i] = root_j

    if n <= max_pw:
        total = n * (n - 1) // 2
        for done, (i, j) in enumerate(combinations(range(n), 2), start=1):
            link_if_similar(i, j)
            if done % 500000 == 0:
                print(f" {done}/{total} pairs ({done/total*100:.1f}%)...")
    else:
        print(f" Using prefix grouping (n={n} > {max_pw})")
        groups = defaultdict(list)
        for i, g in enumerate(glycans):
            parts = g.split("/")
            prefix = "/".join(parts[:3]) if len(parts) >= 3 else g[:50]
            groups[prefix].append(i)
        for indices in groups.values():
            for i, j in combinations(indices, 2):
                link_if_similar(i, j)

    # Group by root; dict insertion order puts clusters in order of their
    # first member's index.
    clusters_map = defaultdict(list)
    for i, g in enumerate(glycans):
        clusters_map[find(i)].append(g)
    clusters = list(clusters_map.values())
    sizes = sorted(len(c) for c in clusters)
    print(f" {len(clusters)} clusters (largest={sizes[-1]}, median={sizes[len(sizes)//2]})")
    return clusters
|
|
def assign_splits(clusters, ratios=(0.7, 0.15, 0.15), seed=42):
    """Greedily assign whole clusters to train/val/test so no cluster spans two splits.

    Clusters are shuffled with a private RNG seeded by ``seed``, to randomize
    the order of equal-sized clusters. They are then stably sorted
    largest-first. Each cluster goes to:

    - train, while train's remaining target is positive and not smaller
      than val's;
    - otherwise val, while val's remaining target is positive;
    - otherwise test.

    Only ``ratios[0]`` and ``ratios[1]`` are used as targets; test receives
    whatever is left. Because train and val keep taking clusters while any
    of their target remains, test tends to end up at or below
    ``ratios[2]``, and it can be empty for small inputs.

    Args:
        clusters: iterable of clusters, each a list of glycan strings.
        ratios: (train, val, test) fractions; the test entry is not read.
        seed: seed for the shuffle. The global ``random`` state is left
            untouched.

    Returns:
        Dict with keys ``"train_glycans"``, ``"val_glycans"`` and
        ``"test_glycans"``, each a flat list of glycan strings.
    """
    # A private Random(seed) yields exactly the same shuffle as seeding the
    # module-level RNG, without clobbering global random state for callers.
    rng = random.Random(seed)
    order = list(clusters)
    rng.shuffle(order)
    order.sort(key=len, reverse=True)
    # Sum over the materialized list so a one-shot iterable also works.
    total = sum(len(c) for c in order)
    train_t, val_t = ratios[0] * total, ratios[1] * total
    splits = {"train_glycans": [], "val_glycans": [], "test_glycans": []}
    counts = {"train": 0, "val": 0, "test": 0}
    for cluster in order:
        train_gap = train_t - counts["train"]
        val_gap = val_t - counts["val"]
        if train_gap >= val_gap and train_gap > 0:
            name = "train"
        elif val_gap > 0:
            name = "val"
        else:
            name = "test"
        splits[f"{name}_glycans"].extend(cluster)
        counts[name] += len(cluster)
    for k, v in splits.items():
        # Guard the percentage so an empty input does not divide by zero.
        share = len(v) / total * 100 if total else 0.0
        print(f" {k}: {len(v)} ({share:.1f}%)")
    return splits
|
|
def main():
    """Command-line entry point: cluster the glycans in a CSV and write glycan-cold splits.

    Steps:
    1. Read every non-blank ``glycan_wurcs`` value from ``--csv``.
    2. Cluster the unique values with ``simple_cluster``.
    3. Assign whole clusters to train/val/test with ``assign_splits``.
    4. Check that the splits are disjoint and together cover exactly the
       input glycans.
    5. Write the split dict to ``--output`` as JSON.

    Raises:
        SystemExit: if the CSV lacks a ``glycan_wurcs`` column, contains
            no glycans, or the splits fail the consistency checks.
    """
    parser = argparse.ArgumentParser(description="Generate glycan-cold train/val/test splits.")
    parser.add_argument("--csv", required=True, help="input CSV with a 'glycan_wurcs' column")
    parser.add_argument("--output", required=True, help="path of the JSON file to write")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="normalized edit-distance cutoff for linking two glycans (default: 0.5)")
    parser.add_argument("--seed", type=int, default=42,
                        help="seed for the cluster-to-split assignment (default: 42)")
    parser.add_argument("--max-pairwise", type=int, default=5000,
                        help="compare all pairs up to this many glycans, prefix grouping above (default: 5000)")
    args = parser.parse_args()

    glycans = set()
    skipped = 0
    # newline="" is required by the csv module for correct handling of
    # quoted fields and line endings.
    with open(args.csv, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if "glycan_wurcs" not in (reader.fieldnames or ()):
            raise SystemExit(f"{args.csv}: missing required column 'glycan_wurcs'")
        for row in reader:
            wurcs = row["glycan_wurcs"]
            # Short rows yield None (DictReader restval); blank cells are not
            # glycans. Either would otherwise break sorting or form a bogus
            # cluster.
            if wurcs is None or not wurcs.strip():
                skipped += 1
                continue
            glycans.add(wurcs)
    if skipped:
        print(f" Skipped {skipped} rows with an empty glycan_wurcs value")
    glycan_list = sorted(glycans)
    print(f" {len(glycan_list)} unique glycans")
    if not glycan_list:
        raise SystemExit(f"{args.csv}: no glycans found")

    clusters = simple_cluster(glycan_list, args.threshold, max_pw=args.max_pairwise)
    splits = assign_splits(clusters, seed=args.seed)

    # Explicit checks rather than assert, so they still run under `python -O`.
    ts, vs, es = (set(splits[k]) for k in ("train_glycans", "val_glycans", "test_glycans"))
    if ts & vs or ts & es or vs & es:
        raise SystemExit("Overlap!")
    if ts | vs | es != set(glycan_list):
        raise SystemExit("Missing glycans!")
    print(f" No overlap, all {len(glycan_list)} assigned")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(splits, f, indent=2)
    print(f" Saved to {args.output}")


if __name__ == "__main__":
    main()
|
|