# fossil_app / closest_sample.py
# piperod91's picture
# Update closest images display: convert gallery to table format with full specimen names
# 8dc677a
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter
def is_git_lfs_pointer(filepath):
    """Check if a file is a Git LFS pointer file instead of actual binary data.

    A Git LFS pointer is a small text file whose first line is the LFS spec
    version string. Returns False for missing, unreadable, or binary files.
    """
    try:
        with open(filepath, 'r') as f:
            first_line = f.readline().strip()
        return first_line == 'version https://git-lfs.github.com/spec/v1'
    except (OSError, UnicodeDecodeError):
        # Missing file, permission problem, or binary content that cannot be
        # decoded as text -- none of these are LFS pointers. (Was a bare
        # `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
        return False
def load_pickle_safe(filepath):
    """Safely load a pickle file, resolving Git LFS pointers if needed.

    If the file on disk is only a Git LFS pointer, attempt to fetch the real
    content with ``git lfs pull`` before unpickling.

    Args:
        filepath: path to the pickle file.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: if *filepath* does not exist.
        ValueError: if the file is an LFS pointer that could not be resolved,
            or if unpickling fails.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Pickle file not found: {filepath}")
    if is_git_lfs_pointer(filepath):
        print(f"Warning: {filepath} is a Git LFS pointer file. Attempting to download actual file...")
        import subprocess
        file_dir = os.path.dirname(os.path.abspath(filepath))
        file_name = os.path.basename(filepath)
        # Try to fetch the real content with `git lfs pull` from the file's
        # own directory. Only the subprocess call sits inside the try block:
        # in the original code the ValueErrors raised after a pull attempt
        # were caught by the generic `except Exception` and re-wrapped into
        # a misleading "failed to download automatically" message.
        try:
            result = subprocess.run(
                ['git', 'lfs', 'pull', '--include', file_name],
                cwd=file_dir if file_dir else '.',
                capture_output=True,
                text=True,
                timeout=60
            )
        except subprocess.TimeoutExpired:
            raise ValueError(
                f"Timeout while trying to download {filepath} from Git LFS.\n"
                f"Please manually run: git lfs pull"
            )
        except FileNotFoundError:
            # `git` itself is not installed / not on PATH.
            raise ValueError(
                f"Error: {filepath} is a Git LFS pointer file, but 'git' command not found.\n"
                f"Please install Git LFS and run 'git lfs pull'"
            )
        except Exception as e:
            raise ValueError(
                f"Error: {filepath} is a Git LFS pointer file.\n"
                f"Failed to download automatically: {e}\n"
                f"Please manually run: git lfs pull"
            )
        if result.returncode == 0:
            print(f"Successfully downloaded {filepath} from Git LFS")
            # Check again: a "successful" pull can still leave a pointer when
            # LFS is not configured for this checkout.
            if is_git_lfs_pointer(filepath):
                raise ValueError(
                    f"Error: {filepath} is still a Git LFS pointer after pull attempt.\n"
                    f"Please ensure Git LFS is properly configured:\n"
                    f" 1. Run 'git lfs install'\n"
                    f" 2. Run 'git lfs pull' in the repository root"
                )
        else:
            raise ValueError(
                f"Error: {filepath} is a Git LFS pointer file.\n"
                f"Failed to download using 'git lfs pull': {result.stderr}\n"
                f"Please manually download the file:\n"
                f" cd {file_dir if file_dir else '.'}\n"
                f" git lfs pull --include {file_name}"
            )
    try:
        with open(filepath, 'rb') as f:
            return pk.load(f)
    except pk.UnpicklingError as e:
        raise ValueError(
            f"Error loading pickle file {filepath}: {e}\n"
            f"This might be a Git LFS pointer file. Please ensure Git LFS is installed and run 'git lfs pull'."
        )
# One-time setup: fetch the fossils dataset snapshot from the Hugging Face
# Hub into ./dataset on first run. Requires a READ_TOKEN env variable for
# authenticated access to the private repo.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # NOTE: the token value itself is deliberately NOT printed -- the
    # original `print(f"Read token:{token}")` leaked a secret into the logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

# Table of fossil image file paths; its row order is assumed to align with
# the embedding arrays (indices from pca_distance index into this) -- TODO
# confirm against how the CSV and .npy files were produced.
fossils_pd = pd.read_csv('all_fossils_filtered_100.csv')
def pca_distance(pca, sample, embedding, top_k):
    """Find the dataset embeddings closest to *sample* in PCA space.

    Args:
        pca: fitted PCA model (anything with a ``transform`` method).
        sample: 1-D feature vector of the query sample.
        embedding: dataset embeddings; only the last slice along axis 1
            (``embedding[:, -1]``) is compared against the sample.
        top_k: number of neighbours requested; note the function returns
            ``top_k + 1`` indices (historical behaviour kept for callers).

    Returns:
        Array with the indices of the ``top_k + 1`` closest embeddings,
        restricted to indices <= 2852 (Florissant-only subset).
    """
    projected_sample = pca.transform(sample.reshape(1, -1))
    # `projected_all` was named `all` before, shadowing the builtin.
    projected_all = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected_all - projected_sample, axis=1)
    sorted_indices = np.argsort(distances)
    # Exclude general fossils, keep florissant only (rows 0..2852).
    filtered_indices = sorted_indices[sorted_indices <= 2852]
    top_indices = filtered_indices[:top_k + 1]
    return top_indices
def return_paths(argsorted, files):
    """Map each index in *argsorted* to the corresponding entry of *files*."""
    return [files[index] for index in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """Download a public image URL to a local file.

    Args:
        url: public HTTP(S) URL of the image.
        destination_path: local path the bytes are written to.
        timeout: request timeout in seconds (default 30). Added so a stalled
            connection cannot hang the app forever -- `requests.get` without
            a timeout blocks indefinitely.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
def get_images(embedding, model_name):
    """Download and describe the closest fossil images for a query sample.

    Args:
        embedding: feature vector of the query sample.
        model_name: one of 'Rock 170', 'Mummified 170', 'Fossils 142';
            selects which PCA model / embedding bank is used.

    Returns:
        Tuple of three parallel lists (one entry per neighbour):
        classes: plant-family name ("Unknown" when the path matched no
            known bucket prefix),
        local_paths: downloaded local file paths ("" on no match),
        filenames: full specimen names derived from the original file
            names ("" on no match).

    Raises:
        ValueError: if *model_name* is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')  # loaded but unused below; kept so a missing file still fails loudly
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')  # loaded but unused below; kept so a missing file still fails loudly
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        print(f'{model_name} not recognized')
        raise ValueError(f'{model_name} not recognized')
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)
    # Cluster-local path prefixes and the public GCS buckets they map to.
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    prefix_florissant = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/'
    prefix_general = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/'
    local_paths = []
    classes = []
    filenames = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(prefix_florissant, folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(prefix_general, folder_general)
        else:
            # Unknown prefix: append placeholders so the three lists stay
            # aligned, then skip the download.
            print("no match found")
            filenames.append("")  # Empty filename if no match
            classes.append("Unknown")
            local_paths.append("")
            continue
        # Full specimen name = original file name without its extension,
        # e.g. .../Fabaceae/Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B.jpg
        # -> Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B
        # (the redundant in-loop `import os` was removed; os is imported at
        # module level).
        original_filename = path.split('/')[-1]
        full_specimen_name = os.path.splitext(original_filename)[0]
        print(f"Original path: {path}")
        print(f"Full specimen name: {full_specimen_name}")
        print(f"Public path: {public_path}")
        download_public_image(public_path, local_file_path)
        filenames.append(full_specimen_name)
        # Plant family is the second-to-last URL segment (the folder name).
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths, filenames
def get_diagram(embedding, top_k, model_name):
    """Plot the plant-family distribution of the closest fossil samples.

    Args:
        embedding: feature vector of the query sample.
        top_k: number of nearest neighbours to aggregate.
        model_name: one of 'Rock 170', 'Mummified 170', 'Fossils 142'.

    Returns:
        Path of the saved bar chart ('class_distribution_chart.png').

    Raises:
        ValueError: if *model_name* is not recognized.
    """
    if model_name in ['Rock 170', 'Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')  # loaded but unused below; kept so a missing file still fails loudly
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')  # loaded but unused below; kept so a missing file still fails loudly
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        print(f'{model_name} not recognized')
        raise ValueError(f'{model_name} not recognized')
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    classes = []
    for path in paths:
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # BUG FIX: the original fell through here and reused a stale
            # `public_path` from a previous iteration (or raised NameError on
            # the first), counting a wrong family. Skip unmatched paths.
            print("no match found")
            continue
        print(public_path)
        # Plant family is the second-to-last URL segment (the folder name).
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
    class_counts = Counter(classes)
    # Sort families by frequency, most common first.
    sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
    fig, ax = plt.subplots()
    ax.bar(sorted_classes, sorted_frequencies, color=colors)
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')
    # BUG FIX: labels previously came from class_counts.keys() (insertion
    # order) while the bars were drawn in sorted order, mislabelling the
    # bars; also set_xticklabels was called without fixing the ticks first.
    ax.set_xticks(range(len(sorted_classes)))
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    # Save the diagram to a file.
    diagram_path = 'class_distribution_chart.png'
    plt.tight_layout()  # make room for the rotated x-axis labels
    plt.savefig(diagram_path)
    plt.close()  # free figure memory
    return diagram_path