bertose-affinose-training-code / code /bertint /generate_glycan_splits.py
supanthadey1's picture
Add BERTose and AFFINose training code release
1d6f391 verified
Raw
History Blame Contribute Delete
4.76 kB
"""
Generate glycan-cold splits for the combined Bertint dataset.
Clusters glycans by WURCS string edit distance, assigns to train/val/test.
"""
import argparse
import json
import csv
import random
from typing import Dict, List, Set, Tuple
from collections import defaultdict
def levenshtein_ratio(s1: str, s2: str) -> float:
    """Return the Levenshtein edit distance of two strings scaled into [0, 1].

    The raw edit distance (unit-cost insertions, deletions and substitutions)
    is divided by the length of the longer string, so 0.0 means the strings
    are identical and 1.0 means they share nothing position-wise.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        0.0 when the strings are equal, 1.0 when exactly one of them is
        empty, otherwise ``distance / max(len(s1), len(s2))``.
    """
    if s1 == s2:
        return 0.0
    if not s1 or not s2:
        return 1.0

    width = len(s2)
    # Only two rows of the DP table are ever needed: the one above and the
    # one currently being filled in.
    above = list(range(width + 1))
    for row_no, ch1 in enumerate(s1, start=1):
        current = [row_no]
        for col, ch2 in enumerate(s2, start=1):
            insertion = current[col - 1] + 1
            deletion = above[col] + 1
            substitution = above[col - 1] + (ch1 != ch2)
            current.append(min(insertion, deletion, substitution))
        above = current

    return above[width] / max(len(s1), width)
def simple_cluster(glycans: List[str], threshold: float = 0.5, max_pw: int = 5000) -> List[List[str]]:
    """Cluster glycans by WURCS similarity using union-find.

    Two glycans are linked when ``levenshtein_ratio`` of their strings is
    strictly below ``threshold``; the clusters are the connected components
    of those links.

    Args:
        glycans: WURCS strings to cluster.
        threshold: Normalized edit distance below which two glycans are
            merged into the same cluster.
        max_pw: Largest input size for which every pair is compared. Above
            it, glycans are first bucketed by their first three
            ``/``-separated fields (or their first 50 characters when there
            are fewer than three fields) and only pairs inside the same
            bucket are compared, so glycans in different buckets are never
            merged.

    Returns:
        A list of clusters, each a list of glycan strings in input order;
        clusters are ordered by their first member's position in
        ``glycans``. An empty input yields an empty list.
    """
    n = len(glycans)
    print(f" Clustering {n} glycans (threshold={threshold})...")
    if n == 0:
        # Nothing to cluster; the summary below would call max() on an
        # empty list and raise ValueError.
        return []

    parent = list(range(n))

    def find(x):
        # Path halving: point each visited node at its grandparent.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            parent[px] = py

    if n <= max_pw:
        total = n * (n - 1) // 2
        done = 0
        for i in range(n):
            for j in range(i + 1, n):
                if levenshtein_ratio(glycans[i], glycans[j]) < threshold:
                    union(i, j)
                done += 1
                if done % 500000 == 0:
                    print(f" {done}/{total} pairs ({done/total*100:.1f}%)...")
    else:
        print(f" Using prefix grouping (n={n} > {max_pw})")
        groups = defaultdict(list)
        for i, g in enumerate(glycans):
            parts = g.split("/")
            prefix = "/".join(parts[:3]) if len(parts) >= 3 else g[:50]
            groups[prefix].append(i)
        for indices in groups.values():
            for a in range(len(indices)):
                for b in range(a + 1, len(indices)):
                    if levenshtein_ratio(glycans[indices[a]], glycans[indices[b]]) < threshold:
                        union(indices[a], indices[b])

    # Collect members under their root; dict insertion order keeps the
    # output deterministic for a given input order.
    clusters_map = defaultdict(list)
    for i in range(n):
        clusters_map[find(i)].append(glycans[i])
    clusters = list(clusters_map.values())
    sizes = [len(c) for c in clusters]
    print(f" {len(clusters)} clusters (largest={max(sizes)}, median={sorted(sizes)[len(sizes)//2]})")
    return clusters
def assign_splits(clusters: List[List[str]], ratios=(0.7, 0.15, 0.15), seed=42) -> Dict[str, List[str]]:
    """Assign whole clusters to train/val/test so no cluster is split.

    Clusters are processed from largest to smallest; clusters of equal size
    are ordered by a shuffle seeded with ``seed``. Each cluster goes to
    train while train still has a positive deficit that is at least as large
    as val's, otherwise to val while val has a positive deficit, otherwise
    to test.

    Args:
        clusters: Lists of glycan strings; each list is kept together.
        ratios: Target fractions ``(train, val, test)``. Only the first two
            are read; test receives whatever is left over.
        seed: Seed for the tie-breaking shuffle. Note that this reseeds the
            module-level ``random`` generator.

    Returns:
        A dict with keys ``"train_glycans"``, ``"val_glycans"`` and
        ``"test_glycans"``, each mapping to a flat list of glycans. All
        three lists are empty when ``clusters`` holds no glycans.
    """
    random.seed(seed)
    ordered = list(clusters)
    random.shuffle(ordered)
    # Stable sort: the shuffle above only decides the order among
    # clusters of equal size.
    ordered.sort(key=len, reverse=True)

    total = sum(len(c) for c in clusters)
    train_t, val_t = ratios[0] * total, ratios[1] * total
    splits = {"train_glycans": [], "val_glycans": [], "test_glycans": []}
    counts = {"train": 0, "val": 0, "test": 0}

    for cluster in ordered:
        tg = train_t - counts["train"]
        vg = val_t - counts["val"]
        if tg >= vg and tg > 0:
            target = "train"
        elif vg > 0:
            target = "val"
        else:
            target = "test"
        splits[f"{target}_glycans"].extend(cluster)
        counts[target] += len(cluster)

    for k, v in splits.items():
        # Guard the percentage: with no glycans at all, total is 0.
        pct = len(v) / total * 100 if total else 0.0
        print(f" {k}: {len(v)} ({pct:.1f}%)")
    return splits
def main():
    """Command-line entry point: build glycan-cold splits from a CSV.

    Reads the ``glycan_wurcs`` column of ``--csv``, clusters the unique
    values with ``simple_cluster`` at ``--threshold``, distributes the
    clusters with ``assign_splits`` using ``--seed``, checks that the three
    splits are disjoint and cover every glycan, and writes them as JSON to
    ``--output``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--csv", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--threshold", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # newline="" leaves newline handling to the csv module, as its
    # documentation requires; otherwise newlines embedded in quoted fields
    # can be misread.
    with open(args.csv, newline="") as f:
        glycans = {row["glycan_wurcs"] for row in csv.DictReader(f)}
    # Sort so clustering sees the same order on every run.
    glycan_list = sorted(glycans)
    print(f" {len(glycan_list)} unique glycans")

    clusters = simple_cluster(glycan_list, args.threshold)
    splits = assign_splits(clusters, seed=args.seed)

    # Sanity checks on the result: splits are disjoint and exhaustive.
    ts = set(splits["train_glycans"])
    vs = set(splits["val_glycans"])
    es = set(splits["test_glycans"])
    assert not (ts & vs) and not (ts & es) and not (vs & es), "Overlap!"
    assert len(ts | vs | es) == len(glycan_list), "Missing glycans!"
    print(f" No overlap, all {len(glycan_list)} assigned")

    with open(args.output, "w") as f:
        json.dump(splits, f, indent=2)
    print(f" Saved to {args.output}")
# Run the split generation only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()