| | import sys |
| | sys.path.append("..") |
| | from dataset import load_data |
| | from Bio.pairwise2 import align |
| |
|
| |
|
def calculate_similarity_score(seq_gen, seq_tem):
    """
    Run a global pairwise alignment (identity scoring, no gap penalties)
    between the two sequences and return the RAW score of the best
    alignment.

    NOTE(review): the original docstring claimed the score is normalized
    by sequence length, but no normalization happens here — callers
    (e.g. compute_intra_dataset_diversity) divide by the query length
    themselves.

    Args:
        seq_gen: query sequence (str).
        seq_tem: template sequence (str).

    Returns:
        Best global alignment score from Bio.pairwise2's globalxx.
    """
    score = align.globalxx(seq_gen, seq_tem)[0].score
    return score
def read_data():
    """
    Load the training workbook and, for each template, collect its
    variant keys while excluding the template sequence itself.

    Returns:
        dict: {template: [variant, ...]} built in place on the loaded
        mapping (values replaced by filtered key lists).
    """
    data = load_data("../dataset/train.xlsx")
    for template, variants in data.items():
        # Bug fix: the original used `i is not k`, an identity check that
        # is unreliable for strings; compare by value instead.
        data[template] = [v for v in variants.keys() if v != template]
    return data
| |
|
| |
|
| | |
| | import os |
| | from collections import Counter, defaultdict |
| | from joblib import Parallel, delayed |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| |
|
def extract_templates_and_variants(data_dict):
    """
    Split the raw mapping into its two working views.

    Args:
        data_dict: {template_L: [variant_with_D1, variant_with_D2, ...]}

    Returns:
        tuple: (templates, variants_map) — the list of L-type template
        sequences (the dict keys, in insertion order) and the original
        mapping object itself.
    """
    return list(data_dict), data_dict
| |
|
def compute_length_distribution(templates):
    """Return each template's sequence length, preserving input order."""
    return list(map(len, templates))
| |
|
def compute_amino_acid_composition(templates):
    """
    Tally residue frequencies across all templates.

    Returns:
        aa_list_sorted: amino acids ordered by descending frequency,
            ties broken alphabetically.
        freqs_sorted: matching normalized frequencies (sum to 1).

    Note: only uppercase letters are expected in templates.
    """
    counts = Counter()
    for seq in templates:
        counts.update(seq)

    n_residues = sum(counts.values())
    if n_residues == 0:
        return [], []

    # Sorting by raw count is equivalent to sorting by normalized
    # frequency (the denominator is constant); ties fall back to the
    # residue letter.
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    aa_list_sorted = [aa for aa, _ in ordered]
    freqs_sorted = [cnt / n_residues for _, cnt in ordered]
    return aa_list_sorted, freqs_sorted
| |
|
def compute_d_substitution_counts(templates, variants_map):
    """
    Count D-type residues (lowercase letters) in every variant.

    Walks templates in order and, within each, its variant list,
    emitting one lowercase-residue count per variant.

    Returns:
        list[int]: counts across all variants of all templates, in
        iteration order.
    """
    return [
        sum(ch.islower() for ch in variant)
        for template in templates
        for variant in variants_map.get(template, [])
    ]
| |
|
def compute_intra_dataset_diversity(templates, n_jobs=-1):
    """
    For every template, find its maximum normalized similarity to any
    other template in the set.

    Each unordered pair is aligned exactly once (in parallel via joblib);
    the raw alignment score is then normalized twice — once by each
    member's length — so that both templates of the pair are treated as
    the query, matching the original calculate_similarity convention
    (score / len(query)).

    Args:
        templates: list of template sequences.
        n_jobs: joblib worker count (-1 = all available cores).

    Returns:
        list[float]: per-template maximum similarity; [1.0] per entry
        when there are fewer than two templates.
    """
    n_templates = len(templates)
    if n_templates <= 1:
        return [1.0] * n_templates

    # All unordered index pairs (a < b); each alignment computed once.
    index_pairs = [
        (a, b)
        for a in range(n_templates)
        for b in range(a + 1, n_templates)
    ]

    raw_scores = Parallel(n_jobs=n_jobs)(
        delayed(calculate_similarity_score)(templates[a], templates[b])
        for a, b in index_pairs
    )

    best = [-np.inf] * n_templates
    for (a, b), raw in zip(index_pairs, raw_scores):
        # Normalize by the query length on each side of the pair.
        best[a] = max(best[a], raw / len(templates[a]))
        best[b] = max(best[b], raw / len(templates[b]))

    return best
| |
|
def choose_bins(data, kind="auto", min_bins=10, max_bins=50):
    """
    Pick a histogram bin count via the Freedman–Diaconis rule, clipped
    to [min_bins, max_bins].

    Args:
        data: sequence of numeric values; empty is allowed.
        kind: kept for interface compatibility; currently unused.
        min_bins: lower clip for the bin count.
        max_bins: upper clip for the bin count.

    Returns:
        int: chosen bin count; 10 when data is empty.
    """
    if not data:
        return 10

    values = np.asarray(data)
    q75, q25 = np.percentile(values, [75, 25])
    # Clamp the IQR away from zero so the FD bin width is strictly
    # positive even for constant-valued data.  (This also means the
    # original `bin_width <= 0` fallback branch was unreachable; it has
    # been removed with identical behavior.)
    iqr = max(q75 - q25, 1e-9)
    bin_width = 2 * iqr / (len(values) ** (1 / 3))
    # Constant data has zero range -> ceil(0) = 0, which the clip below
    # raises to min_bins.
    bins = int(np.ceil((values.max() - values.min()) / bin_width))
    return max(min_bins, min(max_bins, bins))
| |
|
def plot_distributions(templates, variants_map, figsize=(10, 12), save_path=None, show=True):
    """
    Render the four dataset summary panels in a 2x2 grid:
      1. template length histogram
      2. amino-acid composition bar chart (most frequent first)
      3. per-variant D-substitution count histogram
      4. max intra-dataset similarity histogram

    Args:
        templates: list of template sequences.
        variants_map: dict template -> list of variant sequences.
        figsize: matplotlib figure size (inches).
        save_path: if given, the figure is saved there with a tight bbox.
        show: if True, display the figure interactively.
    """
    lengths = compute_length_distribution(templates)
    aa_list, aa_freqs = compute_amino_acid_composition(templates)
    d_counts = compute_d_substitution_counts(templates, variants_map)
    max_sims = compute_intra_dataset_diversity(templates)

    # Bug fix: apply the style *before* creating the figure; rcParams set
    # by style.use only affect artists created afterwards, so the original
    # order left the figure-level settings unstyled.
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=figsize)
    plt.subplots_adjust(hspace=0.5)

    # Panel 1: template length distribution.
    plt.subplot(2, 2, 1)
    if lengths:
        bins_len = choose_bins(lengths)
        plt.hist(lengths, bins=bins_len, color="#4C78A8", edgecolor="white")
    else:
        plt.hist([], bins=10, color="#4C78A8", edgecolor="white")
    plt.title("Length Distribution", fontsize=16, fontweight='bold')
    plt.xlabel("Length", fontsize=14)
    plt.ylabel("Count", fontsize=14)

    # Panel 2: amino-acid composition, ordered by descending frequency.
    plt.subplot(2, 2, 2)
    if aa_list:
        x = np.arange(len(aa_list))
        plt.bar(x, aa_freqs, color="#F58518", edgecolor="white")
        plt.xticks(x, aa_list, fontsize=9)
    else:
        plt.bar([], [])
    plt.title("Amino Acid Composition", fontsize=16, fontweight='bold')
    plt.xlabel("Amino Acid", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)

    # Panel 3: D-substitution counts; half-open bins centered on integers.
    plt.subplot(2, 2, 3)
    if d_counts:
        bins_d = np.arange(0, max(d_counts) + 2) - 0.5
        plt.hist(d_counts, bins=bins_d, color="#E45756", edgecolor="white")
        plt.xticks(range(0, max(d_counts) + 1, 2))
    else:
        plt.hist([], bins=np.arange(-0.5, 1.5), color="#E45756", edgecolor="white")
        plt.xticks([0])
    plt.xlabel("Number of D substitutions per parent", fontsize=14)
    plt.ylabel("Count", fontsize=14)
    plt.title("D-type Substitution Count", fontsize=16, fontweight='bold')

    # Panel 4: max similarity of each template to the rest of the set.
    plt.subplot(2, 2, 4)
    if max_sims:
        bins_sim = min(40, max(10, int(len(max_sims) ** 0.5)))
        plt.hist(max_sims, bins=bins_sim, color="#72B7B2", edgecolor="white")
        plt.xlim(min(max_sims) - 0.05, max(max_sims) + 0.05)
    else:
        plt.hist([], bins=10, color="#72B7B2", edgecolor="white")
    plt.xlabel("Max similarity to other templates", fontsize=14)
    plt.ylabel("Count", fontsize=14)
    plt.title("Intra-dataset Diversity", fontsize=16, fontweight='bold')

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
    plt.close()
| |
|
def main():
    """Load the dataset, derive templates/variants, and render the plots."""
    templates, variants_map = extract_templates_and_variants(read_data())
    plot_distributions(
        templates,
        variants_map,
        figsize=(14, 7),
        save_path="dataset_more.svg",
        show=True,
    )
| |
|
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()