Spaces:
Running
Running
| from sklearn.decomposition import PCA | |
| import pickle as pk | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| from huggingface_hub import snapshot_download | |
| import requests | |
| import matplotlib.pyplot as plt | |
| from collections import Counter | |
def is_git_lfs_pointer(filepath):
    """Check if a file is a Git LFS pointer file instead of actual binary data.

    Args:
        filepath: Path to the file to inspect.

    Returns:
        True if the file's first line is the Git LFS pointer spec header,
        False otherwise (including missing or non-text files).
    """
    try:
        with open(filepath, 'r') as f:
            first_line = f.readline().strip()
        return first_line == 'version https://git-lfs.github.com/spec/v1'
    except (OSError, UnicodeDecodeError):
        # A file we cannot read as text (binary payload, missing file,
        # permission error) is by definition not an LFS pointer.
        return False
def _pull_lfs_file(filepath):
    """Try to replace a Git LFS pointer file with its real content via 'git lfs pull'.

    Args:
        filepath: Path to the pointer file to materialize.

    Raises:
        ValueError: with actionable instructions if the pull fails, times out,
            git is not installed, or the file is still a pointer afterwards.
    """
    import subprocess
    file_dir = os.path.dirname(os.path.abspath(filepath))
    file_name = os.path.basename(filepath)
    try:
        result = subprocess.run(
            ['git', 'lfs', 'pull', '--include', file_name],
            cwd=file_dir if file_dir else '.',
            capture_output=True,
            text=True,
            timeout=60
        )
    except subprocess.TimeoutExpired:
        raise ValueError(
            f"Timeout while trying to download {filepath} from Git LFS.\n"
            f"Please manually run: git lfs pull"
        )
    except FileNotFoundError:
        # The 'git' executable itself is missing from PATH.
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file, but 'git' command not found.\n"
            f"Please install Git LFS and run 'git lfs pull'"
        )
    except Exception as e:
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file.\n"
            f"Failed to download automatically: {e}\n"
            f"Please manually run: git lfs pull"
        )
    # NOTE: these checks are deliberately OUTSIDE the try block above. In the
    # original code the specific ValueErrors raised here were swallowed by the
    # broad 'except Exception' and re-raised with a generic message.
    if result.returncode == 0:
        print(f"Successfully downloaded {filepath} from Git LFS")
        # Verify the pull actually replaced the pointer with real content.
        if is_git_lfs_pointer(filepath):
            raise ValueError(
                f"Error: {filepath} is still a Git LFS pointer after pull attempt.\n"
                f"Please ensure Git LFS is properly configured:\n"
                f" 1. Run 'git lfs install'\n"
                f" 2. Run 'git lfs pull' in the repository root"
            )
    else:
        raise ValueError(
            f"Error: {filepath} is a Git LFS pointer file.\n"
            f"Failed to download using 'git lfs pull': {result.stderr}\n"
            f"Please manually download the file:\n"
            f" cd {file_dir if file_dir else '.'}\n"
            f" git lfs pull --include {file_name}"
        )


def load_pickle_safe(filepath):
    """Safely load a pickle file, checking if it's a Git LFS pointer.

    If the file turns out to be an LFS pointer, attempts to download the real
    content with 'git lfs pull' before unpickling.

    Args:
        filepath: Path to the pickle file.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if the file is an unrecoverable LFS pointer or unpickling fails.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Pickle file not found: {filepath}")
    if is_git_lfs_pointer(filepath):
        print(f"Warning: {filepath} is a Git LFS pointer file. Attempting to download actual file...")
        _pull_lfs_file(filepath)
    try:
        with open(filepath, 'rb') as f:
            return pk.load(f)
    except pk.UnpicklingError as e:
        raise ValueError(
            f"Error loading pickle file {filepath}: {e}\n"
            f"This might be a Git LFS pointer file. Please ensure Git LFS is installed and run 'git lfs pull'."
        )
# --- Import-time dataset bootstrap ---
# Download the fossil dataset snapshot from the Hugging Face Hub unless a
# local 'dataset' directory already exists.
if not os.path.exists('dataset'):
    REPO_ID='Serrelab/Fossils'
    # READ_TOKEN must be provided via environment variable to access the repo.
    token = os.environ.get('READ_TOKEN')
    # NOTE(review): this prints the secret token into logs — consider removing.
    print(f"Read token:{token}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
# Metadata table for the fossil images; the 'file_name' column holds the
# original (cluster) image paths used by get_images/get_diagram below.
fossils_pd= pd.read_csv('all_fossils_filtered_100.csv')
def pca_distance(pca, sample, embedding, top_k):
    """Find the closest Florissant-fossil embeddings to a sample in PCA space.

    Args:
        pca: Fitted PCA model (any object exposing ``transform``).
        sample: 1-D feature vector to query with.
        embedding: Dataset embeddings; ``embedding[:, -1]`` (the last slice
            along axis 1) is what gets projected and compared.
        top_k: Number of neighbours requested.

    Returns:
        Indices of the ``top_k + 1`` closest embeddings by Euclidean distance
        in PCA space, restricted to indices <= 2852 (the Florissant-only
        portion of the dataset; general fossils are excluded). The extra
        element accounts for the query potentially matching itself.
    """
    projected_sample = pca.transform(sample.reshape(1, -1))
    # Renamed from 'all' to avoid shadowing the builtin.
    projected_dataset = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected_dataset - projected_sample, axis=1)
    sorted_indices = np.argsort(distances)
    # Exclude general fossils, keep Florissant only.
    filtered_indices = sorted_indices[sorted_indices <= 2852]
    return filtered_indices[:top_k + 1]
def return_paths(argsorted, files):
    """Return the entries of ``files`` selected by the indices in ``argsorted``.

    Args:
        argsorted: Iterable of integer indices (e.g. from ``np.argsort``).
        files: Indexable collection of file paths.

    Returns:
        List of the selected file paths, in index order.
    """
    return [files[index] for index in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """Download a publicly accessible image to a local file.

    Args:
        url: Public HTTP(S) URL of the image.
        destination_path: Local path the image bytes are written to.
        timeout: Seconds to wait for the server before giving up. The original
            call had no timeout, so a stalled server could hang forever.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
def get_images(embedding, model_name):
    """Download the closest fossil images to a query embedding.

    Looks up the 6 nearest fossil embeddings (top_k=5, plus one extra slot
    from pca_distance), maps their private cluster paths to public bucket
    URLs, and downloads each image locally.

    Args:
        embedding: Query feature vector passed through to pca_distance.
        model_name: One of 'Rock 170', 'Mummified 170', 'Fossils 142';
            selects which PCA model and embedding matrix to load.

    Returns:
        Tuple (classes, local_paths, filenames): plant-family names, local
        image file paths, and specimen names. Entries whose path cannot be
        mapped to a public URL get ("Unknown", "", "") placeholders.

    Raises:
        ValueError: if model_name is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        # pca_leaves is loaded for parity with other code paths; unused below.
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        print(f'{model_name} not recognized')
        raise ValueError(f'{model_name} not recognized')
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)
    # Public bucket roots and the private cluster prefixes they replace.
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    private_florissant = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/'
    private_general = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/'
    local_paths = []
    classes = []
    filenames = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(private_florissant, folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(private_general, folder_general)
        else:
            print("no match found")
            filenames.append("")  # Empty filename if no match
            classes.append("Unknown")
            local_paths.append("")
            continue
        # Specimen name is the path's final component without its extension,
        # e.g. .../Fabaceae/Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B.jpg
        # -> Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B
        # (os is imported at module level; the per-iteration import was removed.)
        original_filename = path.split('/')[-1]
        full_specimen_name = os.path.splitext(original_filename)[0]
        print(f"Original path: {path}")
        print(f"Full specimen name: {full_specimen_name}")
        print(f"Public path: {public_path}")
        download_public_image(public_path, local_file_path)
        filenames.append(full_specimen_name)
        # Plant family is the second-to-last URL component (the folder name).
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths, filenames
def get_diagram(embedding, top_k, model_name):
    """Plot the plant-family distribution of the closest fossil samples.

    Finds the nearest fossil embeddings, extracts each one's plant family
    from its public URL, and saves a sorted bar chart of family frequencies.

    Args:
        embedding: Query feature vector passed through to pca_distance.
        top_k: Number of closest samples to include in the chart.
        model_name: One of 'Rock 170', 'Mummified 170', 'Fossils 142'.

    Returns:
        Path of the saved chart image ('class_distribution_chart.png').

    Raises:
        ValueError: if model_name is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        # pca_leaves is loaded for parity with other code paths; unused below.
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        print(f'{model_name} not recognized')
        raise ValueError(f'{model_name} not recognized')
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    private_florissant = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/'
    private_general = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/'
    classes = []
    for path in paths:
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(private_florissant, folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(private_general, folder_general)
        else:
            # BUG FIX: the original fell through here without 'continue',
            # reusing the previous iteration's public_path (or raising
            # UnboundLocalError on the first iteration).
            print("no match found")
            continue
        print(public_path)
        # Plant family is the second-to-last URL component (the folder name).
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
    class_counts = Counter(classes)
    sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
    fig, ax = plt.subplots()
    ax.bar(sorted_classes, sorted_frequencies, color=colors)
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')
    # BUG FIX: labels must follow the sorted bar order (the original used
    # class_counts.keys(), which is insertion order, mislabelling the bars),
    # and set_xticks must precede set_xticklabels on recent matplotlib.
    ax.set_xticks(range(len(sorted_classes)))
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    diagram_path = 'class_distribution_chart.png'
    plt.tight_layout()  # Make room for the rotated x-axis labels.
    plt.savefig(diagram_path)
    plt.close(fig)  # Free the figure's memory.
    return diagram_path