from sklearn.decomposition import PCA  # the pickled models loaded below are fitted PCA instances
import pickle as pk
import numpy as np
import pandas as pd
import os
import subprocess
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter


def is_git_lfs_pointer(filepath):
    """Check whether a file is a Git LFS pointer file instead of the actual binary data."""
    try:
        with open(filepath, 'r') as f:
            first_line = f.readline().strip()
            return first_line == 'version https://git-lfs.github.com/spec/v1'
    except (OSError, UnicodeDecodeError):
        # An unreadable or binary file is not a pointer file.
        return False
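
# For reference, a Git LFS pointer file is a small text file of the form:
#   version https://git-lfs.github.com/spec/v1
#   oid sha256:<64-character hex digest>
#   size <file size in bytes>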


def load_pickle_safe(filepath):
    """Safely load a pickle file, checking if it's a Git LFS pointer."""
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Pickle file not found: {filepath}")
    
    if is_git_lfs_pointer(filepath):
        print(f"Warning: {filepath} is a Git LFS pointer file. Attempting to download actual file...")
        
        # Try to download using git lfs pull
        try:
            # Get the directory of the file
            file_dir = os.path.dirname(os.path.abspath(filepath))
            file_name = os.path.basename(filepath)
            
            # Try git lfs pull in the file's directory
            result = subprocess.run(
                ['git', 'lfs', 'pull', '--include', file_name],
                cwd=file_dir if file_dir else '.',
                capture_output=True,
                text=True,
                timeout=60
            )
            
            if result.returncode == 0:
                print(f"Successfully downloaded {filepath} from Git LFS")
                # Check again if it's still a pointer
                if is_git_lfs_pointer(filepath):
                    raise ValueError(
                        f"Error: {filepath} is still a Git LFS pointer after pull attempt.\n"
                        f"Please ensure Git LFS is properly configured:\n"
                        f"  1. Run 'git lfs install'\n"
                        f"  2. Run 'git lfs pull' in the repository root"
                    )
            else:
                raise ValueError(
                    f"Error: {filepath} is a Git LFS pointer file.\n"
                    f"Failed to download using 'git lfs pull': {result.stderr}\n"
                    f"Please manually download the file:\n"
                    f"  cd {file_dir if file_dir else '.'}\n"
                    f"  git lfs pull --include {file_name}"
                )
        except subprocess.TimeoutExpired:
            raise ValueError(
                f"Timeout while trying to download {filepath} from Git LFS.\n"
                f"Please manually run: git lfs pull"
            )
        except FileNotFoundError:
            raise ValueError(
                f"Error: {filepath} is a Git LFS pointer file, but 'git' command not found.\n"
                f"Please install Git LFS and run 'git lfs pull'"
            )
        except Exception as e:
            raise ValueError(
                f"Error: {filepath} is a Git LFS pointer file.\n"
                f"Failed to download automatically: {e}\n"
                f"Please manually run: git lfs pull"
            )
    
    try:
        with open(filepath, 'rb') as f:
            return pk.load(f)
    except pk.UnpicklingError as e:
        raise ValueError(
            f"Error loading pickle file {filepath}: {e}\n"
            f"This might be a Git LFS pointer file. Please ensure Git LFS is installed and run 'git lfs pull'."
        )


# Download the dataset snapshot from the Hugging Face Hub on first run.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    print(f"Read token provided: {token is not None}")  # report presence only, never the token itself
    if token is None:
        print("Warning: a READ_TOKEN environment variable is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')


fossils_pd = pd.read_csv('all_fossils_filtered_100.csv')

def pca_distance(pca, sample, embedding, top_k):
    """
    Args:
        pca: fitted PCA model
        sample: query embedding for which to find the closest dataset samples
        embedding: embeddings of the dataset
        top_k: number of closest samples to return
    Returns:
        The indices of the top_k + 1 closest embeddings to the sample (the query itself may be among them).
    """
    s = pca.transform(sample.reshape(1, -1))
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - s, axis=1)
    sorted_indices = np.argsort(distances)
    filtered_indices = sorted_indices[sorted_indices <= 2852]  # exclude general fossils, keep Florissant only
    top_indices = filtered_indices[:top_k + 1]  # alternative: np.concatenate([filtered_indices[:2], filtered_indices[3:top_k+1]])
    return top_indices

def return_paths(argsorted, files):
    """Map the sorted index array back to the corresponding file paths."""
    return [files[i] for i in argsorted]

def download_public_image(url, destination_path):
    """Download an image from a public URL and save it to destination_path."""
    response = requests.get(url, timeout=30)  # timeout avoids hanging indefinitely on a stalled connection
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")

def get_images(embedding, model_name):
    """Return the plant families, local image paths, and specimen names of the closest fossil samples."""
    if model_name in ['Rock 170','Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
        #embedding_leaves = np.load('embedding_leaves.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
        #embedding_leaves = np.load('embedding_leaves.npy')
    else:
        raise ValueError(f'{model_name} not recognized')
    #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
    
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    
    fossils_paths = fossils_pd['file_name'].values
    
    paths = return_paths(pca_d,fossils_paths)
    print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    
    local_paths = []
    classes = []
    filenames = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        public_path = None
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            print("no match found")
            filenames.append("")  # Empty filename if no match
            classes.append("Unknown")
            local_paths.append("")
            continue
        
        # Extract the full specimen name from the original path,
        # e.g. .../Fabaceae/Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B.jpg
        # -> Fabaceae_Robinia_lesquereuxi_Florissant_FLFO_002604B (extension removed).
        original_filename = path.split('/')[-1]  # last path component (filename)
        full_specimen_name = os.path.splitext(original_filename)[0]  # drop the extension
        
        print(f"Original path: {path}")
        print(f"Full specimen name: {full_specimen_name}")
        print(f"Public path: {public_path}")
        download_public_image(public_path, local_file_path)
        
        # Use the full specimen name from the original filename
        filenames.append(full_specimen_name)
        
        # Extract plant family from public_path for classes
        parts = [part for part in public_path.split('/') if part]
        part = parts[-2]  # Plant family is the folder name in the URL
        classes.append(part)
        local_paths.append(local_file_path)
    #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
    #                     '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]

    return classes, local_paths, filenames

def get_diagram(embedding, top_k, model_name):
    """Plot the plant-family distribution of the top_k closest fossil samples and return the saved chart path."""
    if model_name in ['Rock 170','Mummified 170']:
        pca_fossils = load_pickle_safe('pca_fossils_170_finer.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_170_finer.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
        #embedding_leaves = np.load('embedding_leaves.npy')
    elif model_name in ['Fossils 142']:
        pca_fossils = load_pickle_safe('pca_fossils_142_resnet.pkl')
        pca_leaves = load_pickle_safe('pca_leaves_142_resnet.pkl')
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
        #embedding_leaves = np.load('embedding_leaves.npy')
    else:
        raise ValueError(f'{model_name} not recognized')
    
    #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
    
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    
    fossils_paths = fossils_pd['file_name'].values
    
    paths = return_paths(pca_d,fossils_paths)
    #print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'  # only used if the download below is re-enabled
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            print("no match found")
            continue  # skip paths that cannot be mapped to a public URL
        print(public_path)
        #download_public_image(public_path, local_file_path)
        parts = [part for part in public_path.split('/') if part]
        part = parts[-2]  # plant family is the parent folder name in the URL
        classes.append(part)
        #local_paths.append(local_file_path)
    #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
    #                     '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
    class_counts = Counter(classes)

    sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
    fig, ax = plt.subplots()
    ax.bar(sorted_classes, sorted_frequencies, color=colors)
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')
    ax.set_xticks(range(len(sorted_classes)))
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')  # label bars in the same sorted order they are drawn

    # Save the diagram to a file
    diagram_path = 'class_distribution_chart.png'
    plt.tight_layout()  # Adjust layout to make room for rotated x-axis labels
    plt.savefig(diagram_path)
    plt.close()  # Close the figure to free up memory

    return diagram_path
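

# A minimal usage sketch (not part of the original module): the query embedding is assumed
# to be a 1-D feature vector produced by the same backbone that generated the stored dataset
# embeddings, and 'query_embedding.npy' is a hypothetical file name. 'Rock 170' is one of
# the model names handled above.
if __name__ == '__main__':
    query_embedding = np.load('query_embedding.npy')  # hypothetical query vector
    classes, local_paths, filenames = get_images(query_embedding, 'Rock 170')
    print(classes, filenames)
    diagram = get_diagram(query_embedding, top_k=25, model_name='Rock 170')
    print(f"Saved class-distribution chart to {diagram}")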