from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
import subprocess
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter

# First line of every Git LFS pointer file.
_LFS_POINTER_HEADER = b'version https://git-lfs.github.com/spec/v1'

REPO_ID = 'Serrelab/Fossils'

# Public bucket folders that mirror the cluster image folders listed in the CSV.
FOLDER_FLORISSANT = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
FOLDER_GENERAL = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
_CLUSTER_ROOT = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'

# (marker identifying the collection, cluster prefix to replace, public folder).
# Checked in order; the first marker found in a path wins.
_URL_MAPPINGS = (
    ('Florissant_Fossil/512/full/jpg/',
     _CLUSTER_ROOT + 'Florissant_Fossil/512/full/jpg/',
     FOLDER_FLORISSANT),
    ('General_Fossil/512/full/jpg/',
     _CLUSTER_ROOT + 'General_Fossil/512/full/jpg/',
     FOLDER_GENERAL),
)

# model_name -> (fitted fossil PCA pickle, dataset embedding .npy file).
_MODEL_ASSETS = {
    'Rock 170': ('pca_fossils_170_finer.pkl', 'dataset/embedding_fossils_170_finer.npy'),
    'Mummified 170': ('pca_fossils_170_finer.pkl', 'dataset/embedding_fossils_170_finer.npy'),
    'Fossils 142': ('pca_fossils_142_resnet.pkl', 'dataset/embedding_fossils_142_finer.npy'),
}


def is_git_lfs_pointer(filepath):
    """Check if a file is a Git LFS pointer file instead of actual binary data.

    Returns False if the file cannot be opened or read.
    """
    try:
        with open(filepath, 'rb') as f:
            # Bounded binary read: never pull a large binary file into memory
            # just to look at its first line.
            first_line = f.readline(1024)
    except OSError:
        return False
    return first_line.strip() == _LFS_POINTER_HEADER


def _pull_lfs_file(filepath):
    """Replace the LFS pointer at *filepath* with real content via `git lfs pull`; raise ValueError on failure."""
    file_dir = os.path.dirname(os.path.abspath(filepath))  # never empty: path is absolute
    file_name = os.path.basename(filepath)
    try:
        result = subprocess.run(
            ['git', 'lfs', 'pull', '--include', file_name],
            cwd=file_dir,
            capture_output=True,
            text=True,
            timeout=60,
        )
    except subprocess.TimeoutExpired as err:
        raise ValueError(
            f"Timeout while trying to download {filepath} from Git LFS.\n"
            f"Please manually run: git lfs pull"
        ) from err
    except FileNotFoundError as err:
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file, but 'git' command not found.\n"
            f"Please install Git LFS and run 'git lfs pull'"
        ) from err
    except (OSError, subprocess.SubprocessError) as e:
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file.\n"
            f"Failed to download automatically: {e}\n"
            f"Please manually run: git lfs pull"
        ) from e

    # These checks sit outside the try above so their specific messages are
    # not swallowed and re-wrapped by the generic handler.
    if result.returncode != 0:
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file.\n"
            f"Failed to download using 'git lfs pull': {result.stderr}\n"
            f"Please manually download the file:\n"
            f" cd {file_dir}\n"
            f" git lfs pull --include {file_name}"
        )
    print(f"Successfully downloaded {filepath} from Git LFS")
    if is_git_lfs_pointer(filepath):
        raise ValueError(
            f"Error: {filepath} is still a Git LFS pointer after pull attempt.\n"
            f"Please ensure Git LFS is properly configured:\n"
            f" 1. Run 'git lfs install'\n"
            f" 2. Run 'git lfs pull' in the repository root"
        )


def load_pickle_safe(filepath):
    """Load a pickle file, fetching it through Git LFS first if it is only a pointer.

    Raises:
        FileNotFoundError: if *filepath* does not exist.
        ValueError: if the LFS download fails or the content cannot be unpickled.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Pickle file not found: {filepath}")

    if is_git_lfs_pointer(filepath):
        print(f"Warning: {filepath} is a Git LFS pointer file. Attempting to download actual file...")
        _pull_lfs_file(filepath)

    # NOTE(review): pickle executes arbitrary code on load; only use this on
    # trusted artifacts such as the ones shipped with this repository.
    try:
        with open(filepath, 'rb') as f:
            return pk.load(f)
    except pk.UnpicklingError as e:
        raise ValueError(
            f"Error loading pickle file {filepath}: {e}\n"
            f"This might be a Git LFS pointer file. Please ensure Git LFS is installed and run 'git lfs pull'."
        ) from e


# Fetch the dataset snapshot (embeddings etc.) on first run.
if not os.path.exists('dataset'):
    token = os.environ.get('READ_TOKEN')
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        # Report only that a token exists; never echo the secret itself.
        print("Read token: found in READ_TOKEN")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

fossils_pd = pd.read_csv('all_fossils_filtered_100.csv')


def pca_distance(pca, sample, embedding, top_k, max_index=2852):
    """Return dataset indices nearest to *sample* in PCA space, nearest first.

    Args:
        pca: fitted PCA model (anything exposing ``transform``).
        sample: query embedding; reshaped to a single row before projection.
        embedding: dataset embeddings; ``embedding[:, -1]`` (the last entry
            along axis 1 of each item) is what gets projected.
        top_k: controls the result size. Note that ``top_k + 1`` indices are
            returned, not ``top_k``.
        max_index: largest dataset index kept. Rows with a larger index are
            dropped; per the original comment, those are the general fossils,
            so only Florissant specimens remain.

    Returns:
        np.ndarray of at most ``top_k + 1`` indices.
    """
    query = pca.transform(sample.reshape(1, -1))
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - query, axis=1)
    order = np.argsort(distances)
    kept = order[order <= max_index]
    # NOTE(review): top_k + 1 results is long-standing behaviour (get_images
    # asks for top_k=5 and receives 6); kept as-is — confirm intent.
    return kept[:top_k + 1]


def return_paths(argsorted, files):
    """Return ``files[i]`` for every index in *argsorted*, preserving order."""
    return [files[i] for i in argsorted]


def download_public_image(url, destination_path, timeout=30):
    """Download *url* into *destination_path*.

    Returns True on success. Returns False (after printing why) on a
    non-200 status or a request error, so one bad image does not abort
    the whole query.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as err:
        print(f"Failed to download image from bucket: {err}")
        return False
    if response.status_code != 200:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
        return False
    with open(destination_path, 'wb') as f:
        f.write(response.content)
    print(f"Downloaded image to {destination_path}")
    return True


def _load_model_assets(model_name):
    """Return (fitted fossil PCA, dataset embeddings) for *model_name*; raise ValueError if unknown."""
    try:
        pca_path, embedding_path = _MODEL_ASSETS[model_name]
    except KeyError:
        print(f'{model_name} not recognized')
        raise ValueError(f'{model_name} not recognized') from None
    return load_pickle_safe(pca_path), np.load(embedding_path)


def _to_public_url(path):
    """Map a cluster image path from the CSV to its public bucket URL, or None if no collection matches."""
    for marker, cluster_prefix, public_folder in _URL_MAPPINGS:
        if marker in path:
            return path.replace(cluster_prefix, public_folder)
    return None


def _family_from_url(public_path):
    """Return the plant family, i.e. the folder that directly contains the image in *public_path*."""
    parts = [part for part in public_path.split('/') if part]
    return parts[-2]


def get_images(embedding, model_name):
    """Download the fossil images nearest to *embedding* and describe them.

    Args:
        embedding: query embedding passed to ``pca_distance``.
        model_name: one of the keys of ``_MODEL_ASSETS``.

    Returns:
        (classes, local_paths, filenames): parallel lists, one entry per
        neighbour.
        - classes: plant family, or "Unknown" when the path has no public mapping.
        - local_paths: ``image_{i}.jpg``, or "" when unmapped or the download failed.
        - filenames: specimen file name without extension, or "" when unmapped.

    Raises:
        ValueError: if *model_name* is not recognized.
    """
    pca_fossils, embedding_fossils = _load_model_assets(model_name)
    nearest = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    paths = return_paths(nearest, fossils_pd['file_name'].values)
    print(paths)

    local_paths = []
    classes = []
    filenames = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        public_path = _to_public_url(path)
        if public_path is None:
            print("no match found")
            filenames.append("")  # Empty filename if no match
            classes.append("Unknown")
            local_paths.append("")
            continue

        # e.g. .../Fabaceae/Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B.jpg
        #   -> Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B
        full_specimen_name = os.path.splitext(path.split('/')[-1])[0]
        print(f"Original path: {path}")
        print(f"Full specimen name: {full_specimen_name}")
        print(f"Public path: {public_path}")
        downloaded = download_public_image(public_path, local_file_path)

        filenames.append(full_specimen_name)
        classes.append(_family_from_url(public_path))
        # On failure report "" (as for unmapped entries) rather than a path
        # that may be missing or hold a stale image from an earlier query.
        local_paths.append(local_file_path if downloaded else "")
    return classes, local_paths, filenames


def get_diagram(embedding, top_k, model_name):
    """Save a bar chart of plant-family frequencies among the nearest fossils.

    Args:
        embedding: query embedding passed to ``pca_distance``.
        top_k: forwarded to ``pca_distance``, which returns ``top_k + 1`` neighbours.
        model_name: one of the keys of ``_MODEL_ASSETS``.

    Returns:
        Path of the written PNG ('class_distribution_chart.png').

    Raises:
        ValueError: if *model_name* is not recognized.
    """
    pca_fossils, embedding_fossils = _load_model_assets(model_name)
    nearest = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    paths = return_paths(nearest, fossils_pd['file_name'].values)

    classes = []
    for path in paths:
        public_path = _to_public_url(path)
        if public_path is None:
            # Same fallback as get_images. Without it, public_path would be
            # undefined or left over from the previous iteration.
            print("no match found")
            classes.append("Unknown")
            continue
        print(public_path)
        classes.append(_family_from_url(public_path))

    # most_common() sorts by descending count (ties keep first-seen order).
    sorted_class_counts = Counter(classes).most_common()
    if sorted_class_counts:
        sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
    else:
        sorted_classes, sorted_frequencies = (), ()
    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))

    fig, ax = plt.subplots()
    positions = np.arange(len(sorted_classes))
    ax.bar(positions, sorted_frequencies, color=colors)
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')
    # Labels come from sorted_classes so each label sits under its own bar.
    ax.set_xticks(positions)
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')

    # Save the diagram to a file
    diagram_path = 'class_distribution_chart.png'
    fig.tight_layout()  # Adjust layout to make room for rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # Close the figure to free up memory
    return diagram_path