KasaHealth / utils /test_overlap.py
78anand's picture
Upload folder using huggingface_hub
4fcfef4 verified
import numpy as np
import os
from scipy.spatial.distance import cdist
# Locations of the saved embedding (X_*) and label (y_*) arrays.
# Each dataset lives in its own sub-directory of base_dir.
base_dir = r"c:\Users\ASUS\lung_ai_project\data"
_hear_opt_dir = os.path.join(base_dir, "hear_embeddings_optimized")
_coughvid_dir = os.path.join(base_dir, "hear_embeddings_coughvid")

path1_x = os.path.join(_hear_opt_dir, "X_hear_opt_merged.npy")
path1_y = os.path.join(_hear_opt_dir, "y_hear_opt_merged.npy")
path2_x = os.path.join(_coughvid_dir, "X_coughvid.npy")
path2_y = os.path.join(_coughvid_dir, "y_coughvid.npy")
# Load and clean
def clean_y(y):
    """Convert a label array to float32 labels.

    String labels (unicode or byte strings) become 1.0 where the label is
    exactly ``sick`` and 0.0 everywhere else. Arrays of any other dtype are
    simply cast to float32.

    Args:
        y: NumPy array of labels.

    Returns:
        A float32 NumPy array with the same shape as ``y``.
    """
    kind = y.dtype.kind
    if kind == 'U':
        return np.where(y == 'sick', 1, 0).astype(np.float32)
    if kind == 'S':
        # A byte-string array never compares equal to a str literal, so the
        # positive label has to be given as bytes; otherwise every sample
        # would silently be labelled 0.
        return np.where(y == b'sick', 1, 0).astype(np.float32)
    return y.astype(np.float32)
# Load both embedding sets, pass their labels through clean_y, and stack them
# into a single feature matrix X with a matching label vector y.
X1, y1 = np.load(path1_x), clean_y(np.load(path1_y))
X2, y2 = np.load(path2_x), clean_y(np.load(path2_y))
X = np.concatenate([X1, X2], axis=0).astype(np.float32)
y = np.concatenate([y1, y2], axis=0)


def _mean_off_diagonal(square):
    """Return the mean of a square distance matrix, ignoring its diagonal.

    The diagonal holds each sample's distance to itself (zero), which would
    pull a within-class average down and make it incomparable with the
    cross-class average. Returns NaN when there are no distinct pairs.
    """
    n = square.shape[0]
    if n < 2:
        return float("nan")
    return float(square[~np.eye(n, dtype=bool)].mean())


# Randomly sample some sick and healthy to check proximity
sick_indices = np.where(y == 1)[0]
healthy_indices = np.where(y == 0)[0]
if sick_indices.size == 0 or healthy_indices.size == 0:
    raise ValueError(
        "Need at least one sick and one healthy sample; got "
        f"{sick_indices.size} sick and {healthy_indices.size} healthy."
    )

# Pick a small subset to check distances (a full pairwise cdist is too slow).
# The sample size is capped at the class size because sampling without
# replacement fails when more items are requested than are available.
MAX_SAMPLES_PER_CLASS = 500
n_sick = min(MAX_SAMPLES_PER_CLASS, sick_indices.size)
n_healthy = min(MAX_SAMPLES_PER_CLASS, healthy_indices.size)
subs_s = np.random.choice(sick_indices, n_sick, replace=False)
subs_h = np.random.choice(healthy_indices, n_healthy, replace=False)
X_s = X[subs_s]
X_h = X[subs_h]

# Check distances between the sampled sick and healthy embeddings
dist_matrix = cdist(X_s, X_h, 'cosine')

# Find how many sick samples are extremely close to healthy ones
very_close = np.where(dist_matrix < 0.05)
print(f"Overlap Analysis (Cosine Distance < 0.05): {len(very_close[0])} pairs found.")
avg_dist_sick_to_healthy = np.mean(dist_matrix)
print(f"Average Distance (Sick to Healthy): {avg_dist_sick_to_healthy:.4f}")

# Check distances within sick (distinct pairs only, self-distances excluded)
dist_within_sick = cdist(X_s, X_s, 'cosine')
avg_dist_within_sick = _mean_off_diagonal(dist_within_sick)
print(f"Average Distance (Within Sick): {avg_dist_within_sick:.4f}")

# Check distances within healthy (distinct pairs only, self-distances excluded)
dist_within_healthy = cdist(X_h, X_h, 'cosine')
avg_dist_within_healthy = _mean_off_diagonal(dist_within_healthy)
print(f"Average Distance (Within Healthy): {avg_dist_within_healthy:.4f}")