"""
Generate glycan-cold splits for the combined Bertint dataset.

Clusters glycans by WURCS string edit distance, assigns to train/val/test.
"""
import argparse
import json
import csv
import random
from typing import Dict, List, Set, Tuple
from collections import defaultdict


def levenshtein_ratio(s1: str, s2: str) -> float:
    """Return the Levenshtein distance of *s1* and *s2* normalized to [0, 1].

    The raw edit distance is divided by the length of the longer string, so
    0.0 means identical strings and 1.0 means exactly one string is empty or
    the strings share nothing alignable.
    """
    if s1 == s2:
        return 0.0
    len1, len2 = len(s1), len(s2)
    if len1 == 0 or len2 == 0:
        return 1.0
    # Two-row dynamic programme: prev_row holds distances for s1[:i-1],
    # curr_row is filled for s1[:i]; the rows are swapped after each pass.
    prev_row = list(range(len2 + 1))
    curr_row = [0] * (len2 + 1)
    for i in range(1, len1 + 1):
        curr_row[0] = i
        ch1 = s1[i - 1]
        for j in range(1, len2 + 1):
            cost = 0 if ch1 == s2[j - 1] else 1
            curr_row[j] = min(
                curr_row[j - 1] + 1,     # insertion
                prev_row[j] + 1,         # deletion
                prev_row[j - 1] + cost,  # substitution / match
            )
        prev_row, curr_row = curr_row, prev_row
    # After the final swap prev_row holds the last computed row.
    return prev_row[len2] / max(len1, len2)


def simple_cluster(glycans: List[str], threshold: float = 0.5,
                   max_pw: int = 5000) -> List[List[str]]:
    """Cluster glycans by WURCS similarity using union-find.

    Two glycans are linked when ``levenshtein_ratio`` between them is
    strictly below *threshold*. Clusters are the connected components of
    that graph (single linkage).

    If ``len(glycans) <= max_pw``, every pair is compared. Above that, only
    glycans sharing the same first three ``/``-separated fields are compared.
    Glycans with fewer than three fields use their first 50 characters as
    the group key. As a result, clusters never cross those prefix groups.

    Args:
        glycans: WURCS strings to cluster.
        threshold: Normalized distance below which two glycans are joined.
        max_pw: Largest input size for which exhaustive pairwise comparison
            is used.

    Returns:
        A list of clusters, each a list of input strings. Every input element
        appears in exactly one cluster. Empty input returns ``[]``.
    """
    n = len(glycans)
    print(f" Clustering {n} glycans (threshold={threshold})...")
    if n == 0:
        # Nothing to cluster; avoids max()/median of an empty size list.
        return []

    parent = list(range(n))
    lengths = [len(g) for g in glycans]

    def find(x):
        # Path halving keeps the union-find trees shallow.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    def link_if_similar(i, j):
        """Union i and j when their normalized distance is below threshold.

        The expensive edit-distance call is skipped when its result cannot
        change the partition:
        - i and j are already in the same cluster;
        - the length difference alone already puts the ratio at or above
          threshold, since edit distance >= |len_i - len_j|.
        """
        if find(i) == find(j):
            return
        longest = max(lengths[i], lengths[j])
        if longest and abs(lengths[i] - lengths[j]) / longest >= threshold:
            return
        if levenshtein_ratio(glycans[i], glycans[j]) < threshold:
            union(i, j)

    if n <= max_pw:
        total = n * (n - 1) // 2
        done = 0
        for i in range(n):
            for j in range(i + 1, n):
                link_if_similar(i, j)
                done += 1
                if done % 500000 == 0:
                    print(f" {done}/{total} pairs ({done/total*100:.1f}%)...")
    else:
        print(f" Using prefix grouping (n={n} > {max_pw})")
        groups = defaultdict(list)
        for i, g in enumerate(glycans):
            parts = g.split("/")
            prefix = "/".join(parts[:3]) if len(parts) >= 3 else g[:50]
            groups[prefix].append(i)
        for indices in groups.values():
            for a, i in enumerate(indices):
                for j in indices[a + 1:]:
                    link_if_similar(i, j)

    clusters_map = defaultdict(list)
    for i in range(n):
        clusters_map[find(i)].append(glycans[i])
    clusters = list(clusters_map.values())
    sizes = [len(c) for c in clusters]
    print(f" {len(clusters)} clusters (largest={max(sizes)}, median={sorted(sizes)[len(sizes)//2]})")
    return clusters


def assign_splits(clusters: List[List[str]],
                  ratios: Tuple[float, float, float] = (0.7, 0.15, 0.15),
                  seed: int = 42) -> Dict[str, List[str]]:
    """Assign whole clusters to train/val/test so no cluster is split.

    Clusters are shuffled with a private RNG seeded by *seed* (ties only),
    then processed largest-first. Each cluster goes to train or val,
    whichever has more remaining quota. Once neither has positive remaining
    quota, clusters go to test. The test share is therefore the remainder;
    ``ratios[2]`` is not read.

    Args:
        clusters: Lists of glycan strings, e.g. from ``simple_cluster``.
        ratios: Target (train, val, test) fractions of the glycan count.
        seed: Seed for the tie-breaking shuffle. The global ``random``
            state is not touched.

    Returns:
        Dict with keys ``train_glycans``, ``val_glycans``, ``test_glycans``.
    """
    rng = random.Random(seed)
    indexed = list(enumerate(clusters))
    rng.shuffle(indexed)
    # Stable sort: clusters of equal size keep their shuffled order.
    indexed.sort(key=lambda x: len(x[1]), reverse=True)
    total = sum(len(c) for c in clusters)
    train_t, val_t = ratios[0] * total, ratios[1] * total
    splits = {"train_glycans": [], "val_glycans": [], "test_glycans": []}
    counts = {"train": 0, "val": 0, "test": 0}
    for _, cluster in indexed:
        tg = train_t - counts["train"]
        vg = val_t - counts["val"]
        if tg >= vg and tg > 0:
            splits["train_glycans"].extend(cluster)
            counts["train"] += len(cluster)
        elif vg > 0:
            splits["val_glycans"].extend(cluster)
            counts["val"] += len(cluster)
        else:
            splits["test_glycans"].extend(cluster)
            counts["test"] += len(cluster)
    for k, v in splits.items():
        pct = len(v) / total * 100 if total else 0.0
        print(f" {k}: {len(v)} ({pct:.1f}%)")
    return splits


def main():
    """Command-line entry point: read CSV, cluster, split, validate, save JSON."""
    parser = argparse.ArgumentParser(
        description="Generate glycan-cold train/val/test splits.")
    parser.add_argument("--csv", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--max-pairwise", type=int, default=5000,
        help="largest glycan count for exhaustive pairwise comparison")
    args = parser.parse_args()

    glycans = set()
    # newline="" is required by the csv module for correct quoting handling.
    with open(args.csv, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if "glycan_wurcs" not in (reader.fieldnames or []):
            parser.error(f"{args.csv} has no 'glycan_wurcs' column")
        for row in reader:
            wurcs = row["glycan_wurcs"]
            # Skip blank cells and short rows (DictReader fills those with None).
            if wurcs:
                glycans.add(wurcs)
    glycan_list = sorted(glycans)
    print(f" {len(glycan_list)} unique glycans")

    clusters = simple_cluster(glycan_list, args.threshold, args.max_pairwise)
    splits = assign_splits(clusters, seed=args.seed)

    ts = set(splits["train_glycans"])
    vs = set(splits["val_glycans"])
    es = set(splits["test_glycans"])
    # Explicit checks rather than assert, which is stripped under `python -O`.
    if (ts & vs) or (ts & es) or (vs & es):
        raise RuntimeError("Overlap!")
    if len(ts | vs | es) != len(glycan_list):
        raise RuntimeError("Missing glycans!")
    print(f" No overlap, all {len(glycan_list)} assigned")

    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(splits, f, indent=2)
    print(f" Saved to {args.output}")


if __name__ == "__main__":
    main()