rethinks committed on
Commit
1843b5e
·
verified ·
1 Parent(s): 8a9d2af

Upload 4 files

Browse files
app.py CHANGED
@@ -308,6 +308,25 @@ def process_photos_face_filter_only(job_id, upload_dir, session_id=None):
308
  'timestamp': timestamp
309
  })
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  # Sort unmatched by timestamp
312
  unmatched_photos.sort(key=lambda x: x.get('timestamp') or 0)
313
 
@@ -494,6 +513,11 @@ def process_drive_with_parallel_face_detection(job_id, folder_id, upload_dir, fa
494
  print(f" - Photos with your child: {len(matched_photos)}")
495
  print(f" - Photos without match: {len(unmatched_photos)}")
496
  print(f" - Photos with no faces: {len(no_faces_photos)}")
 
 
 
 
 
497
 
498
  # Now create thumbnails and prepare review data
499
  processing_jobs[job_id]['progress'] = 75
@@ -542,6 +566,16 @@ def process_drive_with_parallel_face_detection(job_id, folder_id, upload_dir, fa
542
  'num_faces': 0
543
  })
544
 
 
 
 
 
 
 
 
 
 
 
545
  # Store results
546
  review_data = {
547
  'total_uploaded': total_files[0],
@@ -744,7 +778,7 @@ def save_photos_by_month(job_id, upload_dir, selected_photos, rejected_photos, m
744
  return None
745
 
746
 
747
- def process_photos_quality_selection(job_id, upload_dir, quality_mode, similarity_threshold, confirmed_photos, face_data_cache=None):
748
  """
749
  Phase 2: Month-based category-aware photo selection.
750
  Selects ~40 best photos per month with category diversity.
@@ -752,6 +786,7 @@ def process_photos_quality_selection(job_id, upload_dir, quality_mode, similarit
752
  Args:
753
  face_data_cache: Dict of filename -> {'num_faces': int, 'face_bboxes': list}
754
  Cached face data from Step 2 to avoid re-detection
 
755
  """
756
  face_data_cache = face_data_cache or {}
757
  try:
@@ -761,14 +796,20 @@ def process_photos_quality_selection(job_id, upload_dir, quality_mode, similarit
761
  print(f"[Job {job_id}] Confirmed photos: {len(confirmed_photos)}")
762
  print(f"[Job {job_id}] Quality mode: {quality_mode}")
763
  print(f"[Job {job_id}] Similarity threshold: {similarity_threshold}")
 
764
 
765
  processing_jobs[job_id]['status'] = 'processing'
766
  processing_jobs[job_id]['progress'] = 5
767
- processing_jobs[job_id]['message'] = 'Loading AI models...'
768
 
769
- # Import the new monthly selector
770
- from photo_selector.siglip_embeddings import SigLIPEmbedder
771
  from photo_selector.monthly_selector import MonthlyPhotoSelector
 
 
 
 
 
 
772
 
773
  # Determine target per month based on quality mode
774
  if quality_mode == 'keep_more':
@@ -782,11 +823,11 @@ def process_photos_quality_selection(job_id, upload_dir, quality_mode, similarit
782
 
783
  # Step 1: Generate embeddings for confirmed photos
784
  processing_jobs[job_id]['progress'] = 10
785
- processing_jobs[job_id]['message'] = 'Analyzing photos with SigLIP AI...'
786
 
787
- print(f"[Job {job_id}] Generating SigLIP embeddings for {len(confirmed_photos)} photos...")
788
 
789
- embedder = SigLIPEmbedder()
790
  embeddings = {}
791
 
792
  for i, filename in enumerate(confirmed_photos):
@@ -1844,11 +1885,64 @@ def import_from_drive_reupload(dataset_name):
1844
  print(f"[Job {job_id}] Loaded {len(matcher.reference_embeddings)} reference embeddings")
1845
 
1846
  # Match uploaded files with saved face results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1847
  filtered_photos = face_results.get('filtered_photos', [])
1848
  uploaded_set = set(uploaded_filenames)
1849
- matched_photos = [p for p in filtered_photos if p.get('filename') in uploaded_set]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1850
 
1851
  print(f"[Job {job_id}] Matched {len(matched_photos)} of {len(filtered_photos)} photos")
 
 
 
 
 
 
 
 
 
 
 
1852
 
1853
  # Create review data
1854
  review_data = {
@@ -2677,6 +2771,11 @@ def confirm_selection(job_id):
2677
  if len(confirmed_photos) == 0:
2678
  return jsonify({'error': 'At least one photo must be selected'}), 400
2679
 
 
 
 
 
 
2680
  # Get processing parameters from job
2681
  quality_mode = job.get('quality_mode', 'balanced')
2682
  similarity_threshold = job.get('similarity_threshold', 0.92)
@@ -2717,7 +2816,7 @@ def confirm_selection(job_id):
2717
  # Start phase 2 processing
2718
  thread = threading.Thread(
2719
  target=process_photos_quality_selection,
2720
- args=(job_id, upload_dir, quality_mode, similarity_threshold, confirmed_photos, face_data_cache)
2721
  )
2722
  thread.start()
2723
 
 
308
  'timestamp': timestamp
309
  })
310
 
311
+ # Also include photos that had processing errors
312
+ for error_photo in filter_results.get('error_photos', []):
313
+ filename = os.path.basename(error_photo['path'])
314
+ timestamp = None
315
+ try:
316
+ from photo_selector.utils import get_photo_timestamp
317
+ dt = get_photo_timestamp(error_photo['path'])
318
+ if dt:
319
+ timestamp = dt.timestamp()
320
+ except:
321
+ pass
322
+ unmatched_photos.append({
323
+ 'filename': filename,
324
+ 'best_similarity': 0,
325
+ 'num_faces': 0,
326
+ 'timestamp': timestamp,
327
+ 'error': error_photo.get('error', 'Processing error')
328
+ })
329
+
330
  # Sort unmatched by timestamp
331
  unmatched_photos.sort(key=lambda x: x.get('timestamp') or 0)
332
 
 
513
  print(f" - Photos with your child: {len(matched_photos)}")
514
  print(f" - Photos without match: {len(unmatched_photos)}")
515
  print(f" - Photos with no faces: {len(no_faces_photos)}")
516
+ print(f" - Photos with errors: {len(error_photos)}")
517
+ if error_photos:
518
+ print(f" [ERRORS] First 5 error photos:")
519
+ for ep in error_photos[:5]:
520
+ print(f" - {os.path.basename(ep['path'])}: {ep.get('error', 'Unknown error')}")
521
 
522
  # Now create thumbnails and prepare review data
523
  processing_jobs[job_id]['progress'] = 75
 
566
  'num_faces': 0
567
  })
568
 
569
+ # Also add error photos to unmatched (so they're visible to user)
570
+ for error_photo in error_photos:
571
+ filename = os.path.basename(error_photo['path'])
572
+ unmatched_data.append({
573
+ 'filename': filename,
574
+ 'best_similarity': 0,
575
+ 'num_faces': 0,
576
+ 'error': error_photo.get('error', 'Processing error')
577
+ })
578
+
579
  # Store results
580
  review_data = {
581
  'total_uploaded': total_files[0],
 
778
  return None
779
 
780
 
781
+ def process_photos_quality_selection(job_id, upload_dir, quality_mode, similarity_threshold, confirmed_photos, face_data_cache=None, embedding_model='siglip'):
782
  """
783
  Phase 2: Month-based category-aware photo selection.
784
  Selects ~40 best photos per month with category diversity.
 
786
  Args:
787
  face_data_cache: Dict of filename -> {'num_faces': int, 'face_bboxes': list}
788
  Cached face data from Step 2 to avoid re-detection
789
+ embedding_model: 'siglip' or 'clip' - which embedding model to use
790
  """
791
  face_data_cache = face_data_cache or {}
792
  try:
 
796
  print(f"[Job {job_id}] Confirmed photos: {len(confirmed_photos)}")
797
  print(f"[Job {job_id}] Quality mode: {quality_mode}")
798
  print(f"[Job {job_id}] Similarity threshold: {similarity_threshold}")
799
+ print(f"[Job {job_id}] Embedding model: {embedding_model.upper()}")
800
 
801
  processing_jobs[job_id]['status'] = 'processing'
802
  processing_jobs[job_id]['progress'] = 5
803
+ processing_jobs[job_id]['message'] = f'Loading {embedding_model.upper()} model...'
804
 
805
+ # Import the appropriate embedder based on selection
 
806
  from photo_selector.monthly_selector import MonthlyPhotoSelector
807
+ if embedding_model == 'clip':
808
+ from photo_selector.clip_embeddings import CLIPEmbedder as Embedder
809
+ model_display_name = 'CLIP'
810
+ else:
811
+ from photo_selector.siglip_embeddings import SigLIPEmbedder as Embedder
812
+ model_display_name = 'SigLIP'
813
 
814
  # Determine target per month based on quality mode
815
  if quality_mode == 'keep_more':
 
823
 
824
  # Step 1: Generate embeddings for confirmed photos
825
  processing_jobs[job_id]['progress'] = 10
826
+ processing_jobs[job_id]['message'] = f'Analyzing photos with {model_display_name}...'
827
 
828
+ print(f"[Job {job_id}] Generating {model_display_name} embeddings for {len(confirmed_photos)} photos...")
829
 
830
+ embedder = Embedder()
831
  embeddings = {}
832
 
833
  for i, filename in enumerate(confirmed_photos):
 
1885
  print(f"[Job {job_id}] Loaded {len(matcher.reference_embeddings)} reference embeddings")
1886
 
1887
  # Match uploaded files with saved face results
1888
+ # Google Drive filenames differ from browser upload:
1889
+ # 1. Duplicates: IMG_5197(1).JPG vs IMG_51971.JPG
1890
+ # 2. Spaces: IMG_6970 Copy.JPG vs IMG_6970_Copy.JPG
1891
+ import re
1892
+ def normalize_filename(filename):
1893
+ """Normalize Google Drive filename to match browser upload format."""
1894
+ # Step 1: Convert (N) suffix to N (Google Drive duplicate handling)
1895
+ match = re.match(r'^(.+)\((\d+)\)(\.[^.]+)$', filename)
1896
+ if match:
1897
+ base, num, ext = match.groups()
1898
+ filename = f"{base}{num}{ext}"
1899
+ # Step 2: Apply secure_filename (spaces -> underscores, etc.)
1900
+ return secure_filename(filename)
1901
+
1902
  filtered_photos = face_results.get('filtered_photos', [])
1903
  uploaded_set = set(uploaded_filenames)
1904
+ saved_filenames_set = {p.get('filename') for p in filtered_photos}
1905
+
1906
+ # Create mapping: normalized_name -> actual_uploaded_name
1907
+ normalized_to_uploaded = {normalize_filename(f): f for f in uploaded_filenames}
1908
+
1909
+ matched_photos = []
1910
+ for p in filtered_photos:
1911
+ saved_filename = p.get('filename')
1912
+ actual_filename = None
1913
+
1914
+ # Try direct match first
1915
+ if saved_filename in uploaded_set:
1916
+ actual_filename = saved_filename
1917
+ # Try normalized match (saved name matches normalized uploaded name)
1918
+ elif saved_filename in normalized_to_uploaded:
1919
+ actual_filename = normalized_to_uploaded[saved_filename]
1920
+
1921
+ if actual_filename:
1922
+ # Use actual uploaded filename for the photo entry
1923
+ photo_entry = p.copy()
1924
+ photo_entry['filename'] = actual_filename
1925
+ photo_entry['thumbnail'] = get_thumbnail_name(actual_filename)
1926
+ matched_photos.append(photo_entry)
1927
+
1928
+ # Debug: Find unmatched photos
1929
+ matched_saved = {p.get('filename') for p in filtered_photos if p.get('filename') in uploaded_set or p.get('filename') in normalized_to_uploaded}
1930
+ unmatched_from_saved = [p.get('filename') for p in filtered_photos if p.get('filename') not in matched_saved]
1931
+ matched_uploaded = {m['filename'] for m in matched_photos}
1932
+ unmatched_from_uploaded = [f for f in uploaded_filenames if f not in matched_uploaded]
1933
 
1934
  print(f"[Job {job_id}] Matched {len(matched_photos)} of {len(filtered_photos)} photos")
1935
+ print(f"[Job {job_id}] DEBUG: {len(unmatched_from_saved)} saved photos NOT found in uploaded files:")
1936
+ for fname in unmatched_from_saved[:20]: # Show first 20
1937
+ print(f" [SAVED NOT IN UPLOAD] '{fname}'")
1938
+ if len(unmatched_from_saved) > 20:
1939
+ print(f" ... and {len(unmatched_from_saved) - 20} more")
1940
+
1941
+ print(f"[Job {job_id}] DEBUG: {len(unmatched_from_uploaded)} uploaded files NOT found in saved data:")
1942
+ for fname in unmatched_from_uploaded[:20]: # Show first 20
1943
+ print(f" [UPLOAD NOT IN SAVED] '{fname}'")
1944
+ if len(unmatched_from_uploaded) > 20:
1945
+ print(f" ... and {len(unmatched_from_uploaded) - 20} more")
1946
 
1947
  # Create review data
1948
  review_data = {
 
2771
  if len(confirmed_photos) == 0:
2772
  return jsonify({'error': 'At least one photo must be selected'}), 400
2773
 
2774
+ # Get embedding model selection (default to siglip)
2775
+ embedding_model = data.get('embedding_model', 'siglip')
2776
+ if embedding_model not in ['siglip', 'clip']:
2777
+ embedding_model = 'siglip'
2778
+
2779
  # Get processing parameters from job
2780
  quality_mode = job.get('quality_mode', 'balanced')
2781
  similarity_threshold = job.get('similarity_threshold', 0.92)
 
2816
  # Start phase 2 processing
2817
  thread = threading.Thread(
2818
  target=process_photos_quality_selection,
2819
+ args=(job_id, upload_dir, quality_mode, similarity_threshold, confirmed_photos, face_data_cache, embedding_model)
2820
  )
2821
  thread.start()
2822
 
photo_selector/clip_embeddings.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CLIP embeddings for photo clustering.
3
+ CLIP (Contrastive Language-Image Pre-training) by OpenAI.
4
+
5
+ Uses ViT-B/32 by default (512-dim embeddings)
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ from pathlib import Path
11
+ from PIL import Image
12
+ import torch
13
+ from typing import List, Dict, Tuple, Optional
14
+
15
+ # Try to import CLIP
16
+ try:
17
+ import clip
18
+ CLIP_AVAILABLE = True
19
+ except ImportError:
20
+ CLIP_AVAILABLE = False
21
+ print("CLIP not installed. Run: pip install git+https://github.com/openai/CLIP.git")
22
+
23
+ # HEIC support
24
+ try:
25
+ from pillow_heif import register_heif_opener
26
+ register_heif_opener()
27
+ except ImportError:
28
+ pass
29
+
30
+
31
class CLIPEmbedder:
    """Generate CLIP embeddings for photos.

    Embeddings are L2-normalized on creation, so the dot product of two
    embeddings equals their cosine similarity.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: Optional[str] = None):
        """
        Initialize the CLIP model.

        Args:
            model_name: CLIP model variant. Options:
                - "ViT-B/32" (512-dim, fastest)
                - "ViT-B/16" (512-dim, better quality)
                - "ViT-L/14" (768-dim, best quality)
                - "ViT-L/14@336px" (768-dim, highest resolution)
            device: 'cuda' or 'cpu', auto-detected if None

        Raises:
            ImportError: if the optional `clip` package is not installed.
        """
        if not CLIP_AVAILABLE:
            raise ImportError("CLIP is required. Install with: pip install git+https://github.com/openai/CLIP.git")

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Loading CLIP model '{model_name}' on {self.device}...")

        self.model, self.preprocess = clip.load(model_name, device=self.device)
        self.model.eval()
        self.embedding_dim = self.model.visual.output_dim
        self.model_name = model_name

        print(f"CLIP loaded. Embedding dimension: {self.embedding_dim}")

    def load_image(self, image_path: str) -> Optional["Image.Image"]:
        """Load an image from disk, converting to RGB.

        Returns None (and logs the error) if the file cannot be opened.
        """
        try:
            img = Image.open(image_path)
            # CLIP preprocessing expects 3-channel RGB input.
            if img.mode != 'RGB':
                img = img.convert('RGB')
            return img
        except Exception as e:
            print(f"Error loading {image_path}: {e}")
            return None

    def get_embedding(self, image: "Image.Image") -> np.ndarray:
        """Get the L2-normalized CLIP embedding for a single image as a 1-D array."""
        with torch.no_grad():
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)
            embedding = self.model.encode_image(image_input)
            # Normalize so dot products are cosine similarities.
            embedding = embedding / embedding.norm(dim=-1, keepdim=True)
            return embedding.cpu().numpy().flatten()

    def get_embeddings_batch(self, images: List["Image.Image"], batch_size: int = 32) -> np.ndarray:
        """Get L2-normalized CLIP embeddings for a list of images.

        Returns an (N, embedding_dim) array in the same order as `images`.
        """
        # BUG FIX: np.vstack([]) raises ValueError; return an empty
        # (0, dim) array for an empty input instead of crashing.
        if not images:
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        all_embeddings = []
        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]

            with torch.no_grad():
                # Preprocess all images in batch
                image_inputs = torch.stack([self.preprocess(img) for img in batch_images]).to(self.device)
                embeddings = self.model.encode_image(image_inputs)

                # Normalize
                embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
                all_embeddings.append(embeddings.cpu().numpy())

        return np.vstack(all_embeddings)

    def process_folder(self, folder_path: str,
                       image_extensions: set = None,
                       batch_size: int = 32,
                       use_batching: bool = True) -> Dict[str, np.ndarray]:
        """
        Process all images in a folder and generate embeddings.

        Args:
            folder_path: Path to folder containing images
            image_extensions: Set of valid extensions (defaults to common photo formats)
            batch_size: Number of images to process at once
            use_batching: Whether to use batch processing (faster but more memory)

        Returns:
            Dictionary mapping filename to embedding. Files that fail to
            load or embed are skipped and counted in the error summary.
        """
        if image_extensions is None:
            image_extensions = {'.jpg', '.jpeg', '.png', '.heic', '.heif', '.webp'}

        folder = Path(folder_path)
        image_files = [f for f in folder.iterdir()
                       if f.suffix.lower() in image_extensions]

        print(f"Found {len(image_files)} images in {folder_path}")

        embeddings = {}
        errors = []

        if use_batching and len(image_files) > batch_size:
            # Batch processing for efficiency
            print(f"Using batch processing (batch_size={batch_size})...")

            for batch_start in range(0, len(image_files), batch_size):
                batch_end = min(batch_start + batch_size, len(image_files))
                batch_files = image_files[batch_start:batch_end]

                print(f"Processing batch [{batch_start+1}-{batch_end}/{len(image_files)}]")

                batch_images = []
                batch_names = []

                for image_path in batch_files:
                    try:
                        img = self.load_image(str(image_path))
                        if img is not None:
                            batch_images.append(img)
                            batch_names.append(image_path.name)
                    except Exception as e:
                        errors.append((image_path.name, str(e)))

                try:
                    if batch_images:
                        try:
                            batch_embeddings = self.get_embeddings_batch(batch_images)
                            for name, emb in zip(batch_names, batch_embeddings):
                                embeddings[name] = emb
                        except Exception as e:
                            print(f"Batch processing failed, falling back to individual: {e}")
                            for img, name in zip(batch_images, batch_names):
                                try:
                                    embeddings[name] = self.get_embedding(img)
                                except Exception as e2:
                                    errors.append((name, str(e2)))
                finally:
                    # BUG FIX: release image file handles even if embedding
                    # raised, not only on the success path.
                    for img in batch_images:
                        img.close()
        else:
            # Individual processing
            for i, image_path in enumerate(image_files):
                if (i + 1) % 10 == 0:
                    print(f"Processing [{i+1}/{len(image_files)}] {image_path.name}")

                try:
                    img = self.load_image(str(image_path))
                    if img is not None:
                        try:
                            embeddings[image_path.name] = self.get_embedding(img)
                        finally:
                            # BUG FIX: previously img.close() was skipped when
                            # get_embedding raised, leaking the file handle.
                            img.close()
                except Exception as e:
                    errors.append((image_path.name, str(e)))

        print(f"\nProcessed {len(embeddings)} images successfully")
        if errors:
            print(f"Errors on {len(errors)} images")

        return embeddings

    def save_embeddings(self, embeddings: Dict[str, np.ndarray],
                        output_path: str):
        """Save embeddings (plus model metadata) to a numpy .npz file."""
        data = {
            'filenames': list(embeddings.keys()),
            'embeddings': np.array(list(embeddings.values())),
            'model': self.model_name,
            'embedding_dim': self.embedding_dim
        }
        np.savez(output_path, **data)
        print(f"Saved CLIP embeddings to {output_path}")

    @staticmethod
    def load_embeddings(input_path: str) -> Dict[str, np.ndarray]:
        """Load embeddings from a numpy file saved by save_embeddings."""
        data = np.load(input_path, allow_pickle=True)
        filenames = data['filenames']
        embeddings_array = data['embeddings']
        return {fn: emb for fn, emb in zip(filenames, embeddings_array)}
204
+
205
+
206
def compute_similarity(emb1: np.ndarray, emb2: np.ndarray) -> float:
    """Cosine similarity of two L2-normalized embedding vectors.

    CLIPEmbedder normalizes its outputs, so the plain dot product is
    already the cosine similarity.
    """
    similarity = np.dot(emb1, emb2)
    return float(similarity)
209
+
210
+
211
def find_similar_photos(embeddings: Dict[str, np.ndarray],
                        query_filename: str,
                        top_k: int = 10) -> List[Tuple[str, float]]:
    """Return the top_k photos most similar to query_filename.

    Produces (filename, cosine_similarity) pairs, most similar first,
    excluding the query photo itself.
    """
    query_emb = embeddings[query_filename]

    # Score every other photo against the query (embeddings are unit-length,
    # so a dot product is the cosine similarity).
    scored = [
        (name, float(np.dot(query_emb, emb)))
        for name, emb in embeddings.items()
        if name != query_filename
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]
225
+
226
+
227
if __name__ == "__main__":
    import sys

    # Guard clause: without a folder argument there is nothing to do.
    if len(sys.argv) <= 1:
        print("Usage: python clip_embeddings.py <folder_path>")
        print("\nThis will generate CLIP embeddings for all images in the folder.")
        sys.exit(0)

    folder = sys.argv[1]

    embedder = CLIPEmbedder()
    embeddings = embedder.process_folder(folder)

    # Persist next to this script so repeated runs overwrite the same file.
    output_dir = os.path.dirname(os.path.abspath(__file__))
    embedder.save_embeddings(embeddings, os.path.join(output_dir, "clip_embeddings.npz"))
supabase_storage.py CHANGED
@@ -44,15 +44,26 @@ def is_supabase_available() -> bool:
44
 
45
 
46
  def _get_dataset_registry(client) -> List[str]:
47
- """Get the list of dataset names from the registry file."""
 
 
 
 
48
  try:
49
  storage = client.storage.from_(BUCKET_NAME)
50
  response = storage.download("_registry.json")
51
  registry = json.loads(response.decode('utf-8'))
52
  return registry.get('datasets', [])
53
- except Exception:
54
- # Registry doesn't exist yet
55
- return []
 
 
 
 
 
 
 
56
 
57
 
58
  def _update_dataset_registry(client, dataset_name: str, action: str = 'add'):
@@ -63,6 +74,11 @@ def _update_dataset_registry(client, dataset_name: str, action: str = 'add'):
63
  # Get current registry
64
  datasets = _get_dataset_registry(client)
65
 
 
 
 
 
 
66
  if action == 'add' and dataset_name not in datasets:
67
  datasets.append(dataset_name)
68
  elif action == 'remove' and dataset_name in datasets:
@@ -220,6 +236,11 @@ def list_datasets_from_supabase() -> List[Dict[str, Any]]:
220
  dataset_names = _get_dataset_registry(client)
221
  print(f"[Supabase] Registry contains: {dataset_names}")
222
 
 
 
 
 
 
223
  # If registry is empty, try to find existing datasets by checking known names
224
  # This handles the case where datasets were saved before registry was implemented
225
  if not dataset_names:
 
44
 
45
 
46
  def _get_dataset_registry(client) -> List[str]:
47
+ """
48
+ Get the list of dataset names from the registry file.
49
+ Returns None if there's an error reading (to prevent accidental overwrite).
50
+ Returns [] only if file doesn't exist yet.
51
+ """
52
  try:
53
  storage = client.storage.from_(BUCKET_NAME)
54
  response = storage.download("_registry.json")
55
  registry = json.loads(response.decode('utf-8'))
56
  return registry.get('datasets', [])
57
+ except Exception as e:
58
+ error_str = str(e).lower()
59
+ # Only return empty if file doesn't exist (not for other errors)
60
+ if 'not found' in error_str or '404' in error_str or 'does not exist' in error_str:
61
+ print("[Supabase] Registry file doesn't exist yet, starting fresh")
62
+ return []
63
+ else:
64
+ # For other errors, return None to prevent accidental overwrite
65
+ print(f"[Supabase] ERROR reading registry: {e}")
66
+ return None
67
 
68
 
69
  def _update_dataset_registry(client, dataset_name: str, action: str = 'add'):
 
74
  # Get current registry
75
  datasets = _get_dataset_registry(client)
76
 
77
+ # If we couldn't read the registry (error, not "not found"), don't overwrite
78
+ if datasets is None:
79
+ print(f"[Supabase] Skipping registry update - couldn't read existing registry safely")
80
+ return
81
+
82
  if action == 'add' and dataset_name not in datasets:
83
  datasets.append(dataset_name)
84
  elif action == 'remove' and dataset_name in datasets:
 
236
  dataset_names = _get_dataset_registry(client)
237
  print(f"[Supabase] Registry contains: {dataset_names}")
238
 
239
+ # If registry read failed (None), return empty to be safe
240
+ if dataset_names is None:
241
+ print("[Supabase] Could not read registry, returning empty list")
242
+ return []
243
+
244
  # If registry is empty, try to find existing datasets by checking known names
245
  # This handles the case where datasets were saved before registry was implemented
246
  if not dataset_names:
templates/step3_review.html CHANGED
@@ -835,6 +835,106 @@
835
  padding: 40px;
836
  color: #666;
837
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
  </style>
839
  </head>
840
  <body>
@@ -982,6 +1082,37 @@
982
  <div class="proceed-section">
983
  <h3>Ready to Continue?</h3>
984
  <p>Click below to run quality selection on <strong id="final-count">0</strong> selected photos</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
985
  <div class="proceed-buttons">
986
  <button class="btn btn-success btn-lg" onclick="proceedToSelection()">
987
  Continue to Quality Selection &rarr;
@@ -1036,6 +1167,19 @@
1036
  let photoSelections = {};
1037
  let currentModalPhoto = null;
1038
  let unmatchedLoaded = false;
 
 
 
 
 
 
 
 
 
 
 
 
 
1039
 
1040
  async function loadFilteredPhotos() {
1041
  showLoading('Loading filtered photos...');
@@ -1267,13 +1411,17 @@
1267
  return;
1268
  }
1269
 
1270
- showLoading('Running quality-based selection...');
 
1271
 
1272
  try {
1273
  const response = await fetch(`/confirm_selection/${jobId}`, {
1274
  method: 'POST',
1275
  headers: { 'Content-Type': 'application/json' },
1276
- body: JSON.stringify({ selected_photos: selectedPhotos })
 
 
 
1277
  });
1278
 
1279
  const data = await response.json();
 
835
  padding: 40px;
836
  color: #666;
837
  }
838
+
839
+ /* Model Selection */
840
+ .model-selection {
841
+ background: #f8f9fa;
842
+ border-radius: 12px;
843
+ padding: 20px;
844
+ margin-bottom: 25px;
845
+ border: 1px solid #e0e0e0;
846
+ }
847
+
848
+ .model-selection-title {
849
+ font-size: 14px;
850
+ font-weight: 600;
851
+ color: #374151;
852
+ margin-bottom: 12px;
853
+ display: flex;
854
+ align-items: center;
855
+ gap: 8px;
856
+ }
857
+
858
+ .model-options {
859
+ display: flex;
860
+ gap: 15px;
861
+ flex-wrap: wrap;
862
+ }
863
+
864
+ .model-option {
865
+ flex: 1;
866
+ min-width: 200px;
867
+ background: white;
868
+ border: 2px solid #e0e0e0;
869
+ border-radius: 10px;
870
+ padding: 15px;
871
+ cursor: pointer;
872
+ transition: all 0.2s;
873
+ }
874
+
875
+ .model-option:hover {
876
+ border-color: #667eea;
877
+ }
878
+
879
+ .model-option.selected {
880
+ border-color: #667eea;
881
+ background: linear-gradient(135deg, rgba(102, 126, 234, 0.05) 0%, rgba(118, 75, 162, 0.05) 100%);
882
+ }
883
+
884
+ .model-option input[type="radio"] {
885
+ display: none;
886
+ }
887
+
888
+ .model-option-header {
889
+ display: flex;
890
+ align-items: center;
891
+ gap: 10px;
892
+ margin-bottom: 8px;
893
+ }
894
+
895
+ .model-radio {
896
+ width: 20px;
897
+ height: 20px;
898
+ border: 2px solid #ccc;
899
+ border-radius: 50%;
900
+ display: flex;
901
+ align-items: center;
902
+ justify-content: center;
903
+ flex-shrink: 0;
904
+ }
905
+
906
+ .model-option.selected .model-radio {
907
+ border-color: #667eea;
908
+ }
909
+
910
+ .model-option.selected .model-radio::after {
911
+ content: '';
912
+ width: 10px;
913
+ height: 10px;
914
+ background: #667eea;
915
+ border-radius: 50%;
916
+ }
917
+
918
+ .model-name {
919
+ font-weight: 600;
920
+ color: #333;
921
+ }
922
+
923
+ .model-badge {
924
+ font-size: 10px;
925
+ padding: 2px 8px;
926
+ border-radius: 10px;
927
+ background: #4CAF50;
928
+ color: white;
929
+ font-weight: 500;
930
+ }
931
+
932
+ .model-description {
933
+ font-size: 13px;
934
+ color: #666;
935
+ line-height: 1.4;
936
+ margin-left: 30px;
937
+ }
938
  </style>
939
  </head>
940
  <body>
 
1082
  <div class="proceed-section">
1083
  <h3>Ready to Continue?</h3>
1084
  <p>Click below to run quality selection on <strong id="final-count">0</strong> selected photos</p>
1085
+
1086
+ <!-- Model Selection -->
1087
+ <div class="model-selection">
1088
+ <div class="model-selection-title">
1089
+ Clustering Model
1090
+ </div>
1091
+ <div class="model-options">
1092
+ <label class="model-option selected" onclick="selectModel('siglip')">
1093
+ <input type="radio" name="embedding_model" value="siglip" checked>
1094
+ <div class="model-option-header">
1095
+ <div class="model-radio"></div>
1096
+ <span class="model-name">SigLIP</span>
1097
+ <span class="model-badge">Recommended</span>
1098
+ </div>
1099
+ <div class="model-description">
1100
+ Better for fine-grained visual understanding. 768-dim embeddings.
1101
+ </div>
1102
+ </label>
1103
+ <label class="model-option" onclick="selectModel('clip')">
1104
+ <input type="radio" name="embedding_model" value="clip">
1105
+ <div class="model-option-header">
1106
+ <div class="model-radio"></div>
1107
+ <span class="model-name">CLIP</span>
1108
+ </div>
1109
+ <div class="model-description">
1110
+ Original OpenAI model. 512-dim embeddings. Good general-purpose.
1111
+ </div>
1112
+ </label>
1113
+ </div>
1114
+ </div>
1115
+
1116
  <div class="proceed-buttons">
1117
  <button class="btn btn-success btn-lg" onclick="proceedToSelection()">
1118
  Continue to Quality Selection &rarr;
 
1167
  let photoSelections = {};
1168
  let currentModalPhoto = null;
1169
  let unmatchedLoaded = false;
1170
+ let selectedModel = 'siglip'; // Default model
1171
+
1172
// Record the chosen embedding model and sync the radio-card UI to match.
function selectModel(model) {
    selectedModel = model;
    const options = document.querySelectorAll('.model-option');
    for (const option of options) {
        const matches = option.querySelector(`input[value="${model}"]`) !== null;
        // toggle(class, force) removes when false, adds when true —
        // equivalent to the remove-then-conditionally-add pattern.
        option.classList.toggle('selected', matches);
        if (matches) {
            option.querySelector('input').checked = true;
        }
    }
}
1183
 
1184
  async function loadFilteredPhotos() {
1185
  showLoading('Loading filtered photos...');
 
1411
  return;
1412
  }
1413
 
1414
+ const modelName = selectedModel === 'clip' ? 'CLIP' : 'SigLIP';
1415
+ showLoading(`Running quality-based selection with ${modelName}...`);
1416
 
1417
  try {
1418
  const response = await fetch(`/confirm_selection/${jobId}`, {
1419
  method: 'POST',
1420
  headers: { 'Content-Type': 'application/json' },
1421
+ body: JSON.stringify({
1422
+ selected_photos: selectedPhotos,
1423
+ embedding_model: selectedModel
1424
+ })
1425
  });
1426
 
1427
  const data = await response.json();