astronolan committed on
Commit
b944de3
·
1 Parent(s): cf36170

Cleaned up

Browse files
clip/evaluation/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- """Evaluation utilities for CLIP model."""
2
-
3
- from .inference import ClipInferenceModel
4
-
5
- __all__ = ["ClipInferenceModel"]
 
 
 
 
 
 
clip/evaluation/inference.py DELETED
@@ -1,82 +0,0 @@
1
- """
2
- Inference utilities for trained CLIP model.
3
- """
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from pathlib import Path
9
- from typing import Union, List, Dict, Tuple
10
- import logging
11
-
12
- from ..models import GalaxyClipModel
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
class ClipInferenceModel:
    """Wrapper for using trained CLIP model for inference and search.

    Loads a checkpoint containing 'model_config' and 'model_state_dict',
    rebuilds the GalaxyClipModel with the stored configuration, and exposes
    helpers that project raw image/text embeddings into the shared space.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """
        Initialize inference model.

        Args:
            model_path: Path to saved model (.pt file)
            device: Device to use for inference
        """
        self.device = torch.device(device)

        # Load checkpoint; it must carry both the architecture config and
        # the trained weights.
        checkpoint = torch.load(model_path, map_location=self.device)
        model_config = checkpoint['model_config']

        # Recreate the model with the same architecture used at train time.
        self.model = GalaxyClipModel(
            image_input_dim=model_config['image_input_dim'],
            text_input_dim=model_config['text_input_dim'],
            embedding_dim=model_config['embedding_dim']
        )

        # Load weights and freeze into inference mode (disables dropout /
        # batch-norm updates).
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()

        self.config = model_config
        logger.info(f"Loaded CLIP model on {device}")
        logger.info(f"Model config: {model_config}")

    def _project(self, embeddings, projector):
        """Shared projection path for encode_images/encode_texts.

        Accepts a single embedding (1-D) or a batch (2-D). A 1-D input is
        temporarily batched and the batch axis is removed again on return,
        so output shape mirrors input shape. Result is returned on CPU.
        """
        tensor = torch.as_tensor(embeddings, dtype=torch.float, device=self.device)

        # Remember whether the caller passed a single vector so we can
        # restore the original rank at the end.
        squeeze = tensor.ndim == 1
        if squeeze:
            tensor = tensor.unsqueeze(0)

        with torch.no_grad():
            # NOTE(review): earlier comments claimed the output is
            # normalized, but no normalization happens here — confirm
            # whether the projector normalizes internally before relying
            # on unit-length outputs.
            out = projector(tensor)

        return out.squeeze(0).cpu() if squeeze else out.cpu()

    def encode_images(self, image_embeddings):
        """Encode image embeddings to shared space."""
        return self._project(image_embeddings, self.model.image_projector)

    def encode_texts(self, text_embeddings):
        """Encode text embeddings to shared space."""
        return self._project(text_embeddings, self.model.text_projector)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
clip/utils/__init__.py DELETED
@@ -1,10 +0,0 @@
1
- """Utility functions for CLIP training and evaluation."""
2
-
3
- from .logging_utils import setup_logging
4
- from .io_utils import save_clip_embeddings_hdf5, inspect_generated_files
5
-
6
- __all__ = [
7
- "setup_logging",
8
- "save_clip_embeddings_hdf5",
9
- "inspect_generated_files"
10
- ]
 
 
 
 
 
 
 
 
 
 
 
clip/utils/data_loader.py DELETED
@@ -1,250 +0,0 @@
1
- """
2
- Data loader for multi-text training using unified parquet file with nested text embeddings.
3
- This loader handles the new unified format from 05_generate_unified_embeddings.py.
4
- """
5
-
6
- import numpy as np
7
- import pandas as pd
8
- import torch
9
- from torch.utils.data import Dataset, DataLoader
10
- import logging
11
- from pathlib import Path
12
- import random
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
class UnifiedMultiTextDataset(Dataset):
    """Dataset for unified parquet file with multiple text embeddings per galaxy.

    Each parquet row is expected to hold one 'aion_embedding', one original
    'text_embedding', a list of 'augmented_embeddings', and a parallel
    'description_sources' list whose first entry labels the original text.
    The train/val split is derived deterministically from 'galaxy_index'.
    """

    def __init__(self, parquet_path, split="train", train_ratio=0.8,
                 text_sampling_strategy="random", epoch=0, max_train_samples=None,
                 num_embedding=None):
        # Record configuration; `epoch` feeds the per-epoch text sampling
        # seeds and can be updated later via set_epoch().
        self.parquet_path = Path(parquet_path)
        self.split = split
        self.train_ratio = train_ratio
        self.text_sampling_strategy = text_sampling_strategy
        self.epoch = epoch
        self.max_train_samples = max_train_samples
        self.num_embedding = num_embedding

        # Load the full parquet file into memory.
        logger.info(f"Loading unified embeddings from {self.parquet_path}")
        self.df = pd.read_parquet(self.parquet_path)

        # Create train/val split based on galaxy_index
        n_samples = len(self.df)
        indices = np.arange(n_samples)
        # Fixed seed so the split is the same for the train and val
        # instances built from the same file.
        self.seed = 42

        # Deterministic split based on galaxy_index: hash each galaxy to a
        # pseudo-uniform value in [0, 1) and compare against train_ratio.
        # NOTE(review): relies on the built-in hash(); that is stable for
        # int galaxy indices, but if 'galaxy_index' is a string the split
        # changes between interpreter runs (PYTHONHASHSEED) — confirm the
        # column's dtype.
        split_mask = []
        for idx in range(n_samples):
            galaxy_idx = self.df.iloc[idx]['galaxy_index']
            # Hash the galaxy index for deterministic assignment
            sample_hash = hash((galaxy_idx, self.seed)) % 10000 / 10000.0
            is_train = sample_hash < self.train_ratio
            split_mask.append(is_train)

        split_mask = np.array(split_mask)

        if split == "train":
            self.indices = indices[split_mask]
            # Limit training samples if specified (data-scaling experiments);
            # selection is seeded, so the subset is reproducible.
            if self.max_train_samples is not None and len(self.indices) > self.max_train_samples:
                rng = np.random.RandomState(self.seed)
                selected_indices = rng.choice(self.indices, size=self.max_train_samples, replace=False)
                self.indices = np.sort(selected_indices)  # Sort for reproducibility
                logger.info(f"Limited training set to {self.max_train_samples} samples")
        else:
            # Validation split is the complement of the train mask.
            self.indices = indices[~split_mask]

        logger.info(f"Dataset initialized: {len(self.indices)} samples for {split} split")
        logger.info(f"Text sampling strategy: {text_sampling_strategy}")

        # Validate num_embedding parameter for specific_summary strategy.
        if text_sampling_strategy == "specific_summary" and num_embedding is None:
            raise ValueError("num_embedding parameter is required when using 'specific_summary' strategy")

        # Probe the first row to learn how many text variants each galaxy
        # has (assumes every row has the same count — TODO confirm).
        sample_row = self.df.iloc[0]
        n_augmented = len(sample_row['augmented_embeddings'])
        logger.info(f"Each galaxy has 1 original + {n_augmented} augmented embeddings = {1 + n_augmented} total")

        # Validate num_embedding is within valid range for the probed count.
        if text_sampling_strategy == "specific_summary":
            total_embeddings = 1 + n_augmented
            if num_embedding < 0 or num_embedding >= total_embeddings:
                raise ValueError(f"num_embedding must be between 0 and {total_embeddings-1}, got {num_embedding}")
            logger.info(f"Using specific embedding at index {num_embedding}")

    def __len__(self):
        # Length of the active split, not of the whole file.
        return len(self.indices)

    def set_epoch(self, epoch):
        """Set current epoch for round-robin sampling."""
        self.epoch = epoch

    def _get_all_embeddings_and_sources(self, row):
        """Combine original and augmented embeddings into single lists.

        Returns parallel lists: index 0 is the original embedding/source,
        followed by the augmented ones in stored order.
        """
        # Start with original embedding
        all_embeddings = [np.array(row['text_embedding'], dtype=np.float32)]
        all_sources = [row['description_sources'][0]]  # 'original'

        # Add augmented embeddings; sources[1:] are assumed to be aligned
        # one-to-one with augmented_embeddings — TODO confirm upstream.
        for aug_emb, aug_source in zip(row['augmented_embeddings'], row['description_sources'][1:]):
            all_embeddings.append(np.array(aug_emb, dtype=np.float32))
            all_sources.append(aug_source)

        return all_embeddings, all_sources

    def _sample_text_embedding(self, text_embeddings, text_sources, galaxy_idx):
        """Sample one text embedding from multiple options.

        Returns (embedding, source_label, chosen_index). Randomized
        strategies are seeded with galaxy_idx + epoch so a given epoch is
        reproducible while different epochs see different choices.
        """
        n_texts = len(text_embeddings)

        if self.text_sampling_strategy == "original":
            # Always use original text (index 0)
            idx = 0
        elif self.text_sampling_strategy == "summaries-only":
            # Only use summaries (exclude original at index 0)
            if n_texts > 1:
                rng = random.Random(galaxy_idx + self.epoch * 1000000)
                idx = rng.randint(1, n_texts - 1)  # Start from 1 to exclude original
            else:
                # Fallback to original if no summaries available
                idx = 0
        elif self.text_sampling_strategy == "specific_summary":
            # Use the specific embedding index provided
            if self.num_embedding < n_texts:
                idx = self.num_embedding
            else:
                # Fallback to original if index out of range
                logger.warning(f"Requested embedding index {self.num_embedding} out of range for {n_texts} embeddings, using original")
                idx = 0
        elif self.text_sampling_strategy == "random":
            # Random sampling with seed based on galaxy_idx and epoch
            rng = random.Random(galaxy_idx + self.epoch * 1000000)
            idx = rng.randint(0, n_texts - 1)
        elif self.text_sampling_strategy == "round-robin":
            # Cycle through texts based on epoch
            idx = (self.epoch + galaxy_idx) % n_texts
        elif self.text_sampling_strategy == "weighted":
            # Weight towards original (50%) and summaries (50% / n_summaries each)
            rng = random.Random(galaxy_idx + self.epoch * 1000000)
            n_summaries = n_texts - 1
            if n_summaries > 0:
                summary_weight = 0.5 / n_summaries
                weights = [0.5] + [summary_weight] * n_summaries
            else:
                weights = [1.0]
            idx = rng.choices(range(n_texts), weights=weights)[0]
        else:
            idx = 0  # Default to original

        return text_embeddings[idx], text_sources[idx], idx

    def __getitem__(self, idx):
        """Get a single sample with randomly selected text embedding."""
        # Map split-local index to a row position in the full dataframe.
        actual_idx = self.indices[idx]
        row = self.df.iloc[actual_idx]

        # Get AION embedding
        aion_embedding = np.array(row['aion_embedding'], dtype=np.float32)

        # Get all text embeddings and sources
        text_embeddings, text_sources = self._get_all_embeddings_and_sources(row)

        # Sample one text embedding according to the configured strategy.
        galaxy_idx = row['galaxy_index']
        selected_text, selected_source, text_idx = self._sample_text_embedding(
            text_embeddings, text_sources, galaxy_idx
        )

        # Log selection details periodically (every 100th sample)
        if idx % 100 == 0:
            logger.debug(f"Galaxy {galaxy_idx}: Selected {selected_source} (index {text_idx}) from {len(text_sources)} options")

        return {
            'aion_embedding': torch.from_numpy(aion_embedding),
            'text_embedding': torch.from_numpy(selected_text),
            'galaxy_index': galaxy_idx,
            'text_source': selected_source,
            'text_index': text_idx,
            'object_id': row['object_id']
        }
175
-
176
-
177
def create_unified_multi_text_loaders(
    unified_embeddings_path,
    batch_size=64,
    train_ratio=0.8,
    pin_memory=True,
    text_sampling_strategy="random",
    num_workers=4,
    max_train_samples=None,
    num_embedding=None,
    **kwargs
):
    """
    Build the train and validation DataLoaders from the unified parquet file.

    Args:
        unified_embeddings_path: Path to unified parquet file
        batch_size: Batch size for training
        train_ratio: Fraction of samples for training
        pin_memory: Whether to pin memory for GPU transfer
        text_sampling_strategy: How to sample text embeddings ("original",
            "summaries-only", "specific_summary", "random", "round-robin",
            "weighted")
        num_workers: Number of data loading workers
        max_train_samples: Maximum number of training samples (for data
            scaling experiments; applies to the train split only)
        num_embedding: When using "specific_summary" strategy, the index of
            the embedding to use
        **kwargs: Additional arguments (accepted and ignored)

    Returns:
        Tuple of (train_loader, val_loader).

    Raises:
        ValueError: If the parquet file does not exist.
    """
    data_path = Path(unified_embeddings_path)
    if not data_path.exists():
        raise ValueError(f"Unified embeddings file not found: {data_path}")

    logger.info(f"Creating unified multi-text data loaders from {data_path}")
    logger.info(f"Batch size: {batch_size}, Workers: {num_workers}")
    logger.info(f"Text sampling strategy: {text_sampling_strategy}")

    # One dataset per split; both share the same deterministic split logic,
    # and only the training split honors max_train_samples.
    splits = {}
    for split_name, extra in (
        ("train", {"max_train_samples": max_train_samples}),
        ("val", {}),
    ):
        splits[split_name] = UnifiedMultiTextDataset(
            parquet_path=data_path,
            split=split_name,
            train_ratio=train_ratio,
            text_sampling_strategy=text_sampling_strategy,
            num_embedding=num_embedding,
            **extra,
        )

    # Training shuffles and drops the trailing partial batch for stable
    # batch statistics; validation keeps every sample, in order.
    training_loader = DataLoader(
        splits["train"],
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    validation_loader = DataLoader(
        splits["val"],
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    return training_loader, validation_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
clip/utils/io_utils.py DELETED
@@ -1,103 +0,0 @@
1
- """
2
- I/O utilities for saving and loading CLIP embeddings.
3
- """
4
-
5
- import h5py
6
- import numpy as np
7
- from pathlib import Path
8
- from datetime import datetime
9
- import logging
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
def save_clip_embeddings_hdf5(
    object_ids,
    galaxy_data,
    text_data,
    aion_clip_embeddings,
    text_clip_embeddings,
    output_dir="data/processed"
):
    """Save CLIP embeddings to separate HDF5 files.

    Writes two files with standardized names inside ``output_dir``: one for
    the AION-side CLIP embeddings (coordinates taken from ``galaxy_data``)
    and one for the text-side CLIP embeddings (coordinates taken from
    ``text_data``).

    Args:
        object_ids: Ordered object identifiers; defines row order everywhere.
        galaxy_data: Mapping object_id -> dict with 'ra', 'dec', 'healpix'.
        text_data: Same structure as galaxy_data, used for the text file.
        aion_clip_embeddings: 2-D array (num_objects, embedding_dim).
        text_clip_embeddings: 2-D array (num_objects, embedding_dim).
        output_dir: Directory for the output files (created if missing).

    Returns:
        Tuple (aion_clip_path, text_clip_path) of the written file paths.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # File paths (standardized names)
    aion_clip_path = output_dir / "galaxy_aion_clip_embeddings.hdf5"
    text_clip_path = output_dir / "galaxy_text_clip_embeddings.hdf5"

    logger.info(f"Saving AION CLIP embeddings to: {aion_clip_path}")
    _write_clip_embeddings_file(
        aion_clip_path, object_ids, galaxy_data,
        'AION_clip_embedding', aion_clip_embeddings,
        'AION embeddings encoded through trained CLIP model'
    )

    logger.info(f"Saving text CLIP embeddings to: {text_clip_path}")
    _write_clip_embeddings_file(
        text_clip_path, object_ids, text_data,
        'text_clip_embedding', text_clip_embeddings,
        'Text embeddings encoded through trained CLIP model'
    )

    return aion_clip_path, text_clip_path


def _write_clip_embeddings_file(path, object_ids, coord_data, dataset_name,
                                embeddings, description):
    """Write one HDF5 embeddings file (shared layout for both outputs).

    Args:
        path: Destination HDF5 file path (overwritten).
        object_ids: Ordered object identifiers.
        coord_data: Mapping object_id -> dict with 'ra', 'dec', 'healpix'.
        dataset_name: Name of the embeddings dataset inside the file.
        embeddings: 2-D array (num_objects, embedding_dim).
        description: Value for the file-level 'description' attribute.
    """
    with h5py.File(path, 'w') as f:
        # Object IDs stored as variable-length strings.
        dt = h5py.special_dtype(vlen=str)
        f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)

        # Coordinates and metadata, row-aligned with object_ids.
        ra_values = np.array([coord_data[oid]['ra'] for oid in object_ids])
        dec_values = np.array([coord_data[oid]['dec'] for oid in object_ids])
        healpix_values = np.array([coord_data[oid]['healpix'] for oid in object_ids])

        f.create_dataset('ra', data=ra_values, dtype=np.float64)
        f.create_dataset('dec', data=dec_values, dtype=np.float64)
        f.create_dataset('healpix', data=healpix_values, dtype=np.int64)

        # The CLIP embeddings themselves.
        f.create_dataset(dataset_name, data=embeddings, dtype=np.float32)

        # File-level metadata.
        f.attrs['description'] = description
        f.attrs['embedding_dim'] = embeddings.shape[1]
        f.attrs['num_objects'] = len(object_ids)
        f.attrs['created'] = datetime.now().isoformat()
83
-
84
-
85
def inspect_generated_files(aion_clip_path, text_clip_path):
    """Inspect the generated HDF5 files.

    Logs, for each of the two files, the dataset names, each dataset's
    shape/dtype, and the file-level attributes.
    """
    # (path, lower-case label for the intro line, title-case label for the
    # datasets line) — labels reproduce the historical log wording exactly.
    targets = (
        (aion_clip_path, "AION", "AION"),
        (text_clip_path, "text", "Text"),
    )

    for path, intro_label, list_label in targets:
        logger.info(f"Inspecting generated {intro_label} CLIP embeddings file...")

        with h5py.File(path, 'r') as handle:
            logger.info(f"{list_label} file datasets: {list(handle.keys())}")
            for name in handle.keys():
                dataset = handle[name]
                logger.info(f"  {name}: shape={dataset.shape}, dtype={dataset.dtype}")
            logger.info(f"  Attributes: {dict(handle.attrs)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
clip/utils/logging_utils.py DELETED
@@ -1,42 +0,0 @@
1
- """Logging utilities."""
2
-
3
- import logging
4
- import sys
5
- from pathlib import Path
6
-
7
-
8
def setup_logging(log_level: str = "INFO", log_file: str = None):
    """
    Configure the root logger for console (and optional file) output.

    Any previously installed handlers are removed first, so calling this
    repeatedly does not duplicate output.

    Args:
        log_level: Logging level name (DEBUG, INFO, WARNING, ERROR)
        log_file: Optional path to log file; parent directories are created

    Returns:
        The configured root logger.
    """
    root = logging.getLogger()
    root.handlers.clear()
    root.setLevel(getattr(logging, log_level.upper()))

    # Single shared format for every handler.
    fmt = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Console output goes to stdout rather than the default stderr.
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(fmt)
    root.addHandler(stream_handler)

    # Optionally mirror everything to a file as well.
    if log_file:
        target = Path(log_file)
        target.parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.FileHandler(target)
        file_handler.setFormatter(fmt)
        root.addHandler(file_handler)

    return root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py DELETED
@@ -1,6 +0,0 @@
1
def main():
    """Entry point: print the project greeting."""
    greeting = "Hello from aion-search!"
    print(greeting)


if __name__ == "__main__":
    main()