Spaces:
Running
Running
Commit
Β·
b944de3
1
Parent(s):
cf36170
Cleaned up
Browse files- clip/evaluation/__init__.py +0 -5
- clip/evaluation/inference.py +0 -82
- clip/utils/__init__.py +0 -10
- clip/utils/data_loader.py +0 -250
- clip/utils/io_utils.py +0 -103
- clip/utils/logging_utils.py +0 -42
- main.py +0 -6
clip/evaluation/__init__.py
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
"""Evaluation utilities for CLIP model."""
|
| 2 |
-
|
| 3 |
-
from .inference import ClipInferenceModel
|
| 4 |
-
|
| 5 |
-
__all__ = ["ClipInferenceModel"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip/evaluation/inference.py
DELETED
|
@@ -1,82 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Inference utilities for trained CLIP model.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn.functional as F
|
| 7 |
-
import numpy as np
|
| 8 |
-
from pathlib import Path
|
| 9 |
-
from typing import Union, List, Dict, Tuple
|
| 10 |
-
import logging
|
| 11 |
-
|
| 12 |
-
from ..models import GalaxyClipModel
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class ClipInferenceModel:
|
| 18 |
-
"""Wrapper for using trained CLIP model for inference and search."""
|
| 19 |
-
|
| 20 |
-
def __init__(self, model_path: str, device: str = "cpu"):
|
| 21 |
-
"""
|
| 22 |
-
Initialize inference model.
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
model_path: Path to saved model (.pt file)
|
| 26 |
-
device: Device to use for inference
|
| 27 |
-
"""
|
| 28 |
-
self.device = torch.device(device)
|
| 29 |
-
|
| 30 |
-
# Load model
|
| 31 |
-
checkpoint = torch.load(model_path, map_location=self.device)
|
| 32 |
-
model_config = checkpoint['model_config']
|
| 33 |
-
|
| 34 |
-
# Create model with same config
|
| 35 |
-
self.model = GalaxyClipModel(
|
| 36 |
-
image_input_dim=model_config['image_input_dim'],
|
| 37 |
-
text_input_dim=model_config['text_input_dim'],
|
| 38 |
-
embedding_dim=model_config['embedding_dim']
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
-
# Load weights
|
| 42 |
-
self.model.load_state_dict(checkpoint['model_state_dict'])
|
| 43 |
-
self.model.to(self.device)
|
| 44 |
-
self.model.eval()
|
| 45 |
-
|
| 46 |
-
self.config = model_config
|
| 47 |
-
logger.info(f"Loaded CLIP model on {device}")
|
| 48 |
-
logger.info(f"Model config: {model_config}")
|
| 49 |
-
|
| 50 |
-
def encode_images(self, image_embeddings):
|
| 51 |
-
"""Encode image embeddings to shared space."""
|
| 52 |
-
|
| 53 |
-
tensor = torch.as_tensor(image_embeddings, dtype=torch.float, device=self.device)
|
| 54 |
-
|
| 55 |
-
if tensor.ndim == 1:
|
| 56 |
-
tensor = tensor.unsqueeze(0)
|
| 57 |
-
squeeze = True
|
| 58 |
-
else:
|
| 59 |
-
squeeze = False
|
| 60 |
-
|
| 61 |
-
with torch.no_grad():
|
| 62 |
-
# Use image_projector and normalize
|
| 63 |
-
out = self.model.image_projector(tensor)
|
| 64 |
-
|
| 65 |
-
return out.squeeze(0).cpu() if squeeze else out.cpu()
|
| 66 |
-
|
| 67 |
-
def encode_texts(self, text_embeddings):
|
| 68 |
-
"""Encode text embeddings to shared space."""
|
| 69 |
-
|
| 70 |
-
tensor = torch.as_tensor(text_embeddings, dtype=torch.float, device=self.device)
|
| 71 |
-
|
| 72 |
-
if tensor.ndim == 1:
|
| 73 |
-
tensor = tensor.unsqueeze(0)
|
| 74 |
-
squeeze = True
|
| 75 |
-
else:
|
| 76 |
-
squeeze = False
|
| 77 |
-
|
| 78 |
-
with torch.no_grad():
|
| 79 |
-
# Use text_projector and normalize
|
| 80 |
-
out = self.model.text_projector(tensor)
|
| 81 |
-
|
| 82 |
-
return out.squeeze(0).cpu() if squeeze else out.cpu()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip/utils/__init__.py
DELETED
|
@@ -1,10 +0,0 @@
|
|
| 1 |
-
"""Utility functions for CLIP training and evaluation."""
|
| 2 |
-
|
| 3 |
-
from .logging_utils import setup_logging
|
| 4 |
-
from .io_utils import save_clip_embeddings_hdf5, inspect_generated_files
|
| 5 |
-
|
| 6 |
-
__all__ = [
|
| 7 |
-
"setup_logging",
|
| 8 |
-
"save_clip_embeddings_hdf5",
|
| 9 |
-
"inspect_generated_files"
|
| 10 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip/utils/data_loader.py
DELETED
|
@@ -1,250 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Data loader for multi-text training using unified parquet file with nested text embeddings.
|
| 3 |
-
This loader handles the new unified format from 05_generate_unified_embeddings.py.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import numpy as np
|
| 7 |
-
import pandas as pd
|
| 8 |
-
import torch
|
| 9 |
-
from torch.utils.data import Dataset, DataLoader
|
| 10 |
-
import logging
|
| 11 |
-
from pathlib import Path
|
| 12 |
-
import random
|
| 13 |
-
|
| 14 |
-
logger = logging.getLogger(__name__)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class UnifiedMultiTextDataset(Dataset):
|
| 18 |
-
"""Dataset for unified parquet file with multiple text embeddings per galaxy."""
|
| 19 |
-
|
| 20 |
-
def __init__(self, parquet_path, split="train", train_ratio=0.8,
|
| 21 |
-
text_sampling_strategy="random", epoch=0, max_train_samples=None,
|
| 22 |
-
num_embedding=None):
|
| 23 |
-
self.parquet_path = Path(parquet_path)
|
| 24 |
-
self.split = split
|
| 25 |
-
self.train_ratio = train_ratio
|
| 26 |
-
self.text_sampling_strategy = text_sampling_strategy
|
| 27 |
-
self.epoch = epoch
|
| 28 |
-
self.max_train_samples = max_train_samples
|
| 29 |
-
self.num_embedding = num_embedding
|
| 30 |
-
|
| 31 |
-
# Load the parquet file
|
| 32 |
-
logger.info(f"Loading unified embeddings from {self.parquet_path}")
|
| 33 |
-
self.df = pd.read_parquet(self.parquet_path)
|
| 34 |
-
|
| 35 |
-
# Create train/val split based on galaxy_index
|
| 36 |
-
n_samples = len(self.df)
|
| 37 |
-
indices = np.arange(n_samples)
|
| 38 |
-
self.seed = 42
|
| 39 |
-
|
| 40 |
-
# Deterministic split based on galaxy_index
|
| 41 |
-
split_mask = []
|
| 42 |
-
for idx in range(n_samples):
|
| 43 |
-
galaxy_idx = self.df.iloc[idx]['galaxy_index']
|
| 44 |
-
# Hash the galaxy index for deterministic assignment
|
| 45 |
-
sample_hash = hash((galaxy_idx, self.seed)) % 10000 / 10000.0
|
| 46 |
-
is_train = sample_hash < self.train_ratio
|
| 47 |
-
split_mask.append(is_train)
|
| 48 |
-
|
| 49 |
-
split_mask = np.array(split_mask)
|
| 50 |
-
|
| 51 |
-
if split == "train":
|
| 52 |
-
self.indices = indices[split_mask]
|
| 53 |
-
# Limit training samples if specified
|
| 54 |
-
if self.max_train_samples is not None and len(self.indices) > self.max_train_samples:
|
| 55 |
-
rng = np.random.RandomState(self.seed)
|
| 56 |
-
selected_indices = rng.choice(self.indices, size=self.max_train_samples, replace=False)
|
| 57 |
-
self.indices = np.sort(selected_indices) # Sort for reproducibility
|
| 58 |
-
logger.info(f"Limited training set to {self.max_train_samples} samples")
|
| 59 |
-
else:
|
| 60 |
-
self.indices = indices[~split_mask]
|
| 61 |
-
|
| 62 |
-
logger.info(f"Dataset initialized: {len(self.indices)} samples for {split} split")
|
| 63 |
-
logger.info(f"Text sampling strategy: {text_sampling_strategy}")
|
| 64 |
-
|
| 65 |
-
# Validate num_embedding parameter for specific_summary strategy
|
| 66 |
-
if text_sampling_strategy == "specific_summary" and num_embedding is None:
|
| 67 |
-
raise ValueError("num_embedding parameter is required when using 'specific_summary' strategy")
|
| 68 |
-
|
| 69 |
-
# Check data structure
|
| 70 |
-
sample_row = self.df.iloc[0]
|
| 71 |
-
n_augmented = len(sample_row['augmented_embeddings'])
|
| 72 |
-
logger.info(f"Each galaxy has 1 original + {n_augmented} augmented embeddings = {1 + n_augmented} total")
|
| 73 |
-
|
| 74 |
-
# Validate num_embedding is within valid range
|
| 75 |
-
if text_sampling_strategy == "specific_summary":
|
| 76 |
-
total_embeddings = 1 + n_augmented
|
| 77 |
-
if num_embedding < 0 or num_embedding >= total_embeddings:
|
| 78 |
-
raise ValueError(f"num_embedding must be between 0 and {total_embeddings-1}, got {num_embedding}")
|
| 79 |
-
logger.info(f"Using specific embedding at index {num_embedding}")
|
| 80 |
-
|
| 81 |
-
def __len__(self):
|
| 82 |
-
return len(self.indices)
|
| 83 |
-
|
| 84 |
-
def set_epoch(self, epoch):
|
| 85 |
-
"""Set current epoch for round-robin sampling."""
|
| 86 |
-
self.epoch = epoch
|
| 87 |
-
|
| 88 |
-
def _get_all_embeddings_and_sources(self, row):
|
| 89 |
-
"""Combine original and augmented embeddings into single lists."""
|
| 90 |
-
# Start with original embedding
|
| 91 |
-
all_embeddings = [np.array(row['text_embedding'], dtype=np.float32)]
|
| 92 |
-
all_sources = [row['description_sources'][0]] # 'original'
|
| 93 |
-
|
| 94 |
-
# Add augmented embeddings
|
| 95 |
-
for aug_emb, aug_source in zip(row['augmented_embeddings'], row['description_sources'][1:]):
|
| 96 |
-
all_embeddings.append(np.array(aug_emb, dtype=np.float32))
|
| 97 |
-
all_sources.append(aug_source)
|
| 98 |
-
|
| 99 |
-
return all_embeddings, all_sources
|
| 100 |
-
|
| 101 |
-
def _sample_text_embedding(self, text_embeddings, text_sources, galaxy_idx):
|
| 102 |
-
"""Sample one text embedding from multiple options."""
|
| 103 |
-
n_texts = len(text_embeddings)
|
| 104 |
-
|
| 105 |
-
if self.text_sampling_strategy == "original":
|
| 106 |
-
# Always use original text (index 0)
|
| 107 |
-
idx = 0
|
| 108 |
-
elif self.text_sampling_strategy == "summaries-only":
|
| 109 |
-
# Only use summaries (exclude original at index 0)
|
| 110 |
-
if n_texts > 1:
|
| 111 |
-
rng = random.Random(galaxy_idx + self.epoch * 1000000)
|
| 112 |
-
idx = rng.randint(1, n_texts - 1) # Start from 1 to exclude original
|
| 113 |
-
else:
|
| 114 |
-
# Fallback to original if no summaries available
|
| 115 |
-
idx = 0
|
| 116 |
-
elif self.text_sampling_strategy == "specific_summary":
|
| 117 |
-
# Use the specific embedding index provided
|
| 118 |
-
if self.num_embedding < n_texts:
|
| 119 |
-
idx = self.num_embedding
|
| 120 |
-
else:
|
| 121 |
-
# Fallback to original if index out of range
|
| 122 |
-
logger.warning(f"Requested embedding index {self.num_embedding} out of range for {n_texts} embeddings, using original")
|
| 123 |
-
idx = 0
|
| 124 |
-
elif self.text_sampling_strategy == "random":
|
| 125 |
-
# Random sampling with seed based on galaxy_idx and epoch
|
| 126 |
-
rng = random.Random(galaxy_idx + self.epoch * 1000000)
|
| 127 |
-
idx = rng.randint(0, n_texts - 1)
|
| 128 |
-
elif self.text_sampling_strategy == "round-robin":
|
| 129 |
-
# Cycle through texts based on epoch
|
| 130 |
-
idx = (self.epoch + galaxy_idx) % n_texts
|
| 131 |
-
elif self.text_sampling_strategy == "weighted":
|
| 132 |
-
# Weight towards original (50%) and summaries (50% / n_summaries each)
|
| 133 |
-
rng = random.Random(galaxy_idx + self.epoch * 1000000)
|
| 134 |
-
n_summaries = n_texts - 1
|
| 135 |
-
if n_summaries > 0:
|
| 136 |
-
summary_weight = 0.5 / n_summaries
|
| 137 |
-
weights = [0.5] + [summary_weight] * n_summaries
|
| 138 |
-
else:
|
| 139 |
-
weights = [1.0]
|
| 140 |
-
idx = rng.choices(range(n_texts), weights=weights)[0]
|
| 141 |
-
else:
|
| 142 |
-
idx = 0 # Default to original
|
| 143 |
-
|
| 144 |
-
return text_embeddings[idx], text_sources[idx], idx
|
| 145 |
-
|
| 146 |
-
def __getitem__(self, idx):
|
| 147 |
-
"""Get a single sample with randomly selected text embedding."""
|
| 148 |
-
actual_idx = self.indices[idx]
|
| 149 |
-
row = self.df.iloc[actual_idx]
|
| 150 |
-
|
| 151 |
-
# Get AION embedding
|
| 152 |
-
aion_embedding = np.array(row['aion_embedding'], dtype=np.float32)
|
| 153 |
-
|
| 154 |
-
# Get all text embeddings and sources
|
| 155 |
-
text_embeddings, text_sources = self._get_all_embeddings_and_sources(row)
|
| 156 |
-
|
| 157 |
-
# Sample one text embedding
|
| 158 |
-
galaxy_idx = row['galaxy_index']
|
| 159 |
-
selected_text, selected_source, text_idx = self._sample_text_embedding(
|
| 160 |
-
text_embeddings, text_sources, galaxy_idx
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
# Log selection details periodically (every 100th sample)
|
| 164 |
-
if idx % 100 == 0:
|
| 165 |
-
logger.debug(f"Galaxy {galaxy_idx}: Selected {selected_source} (index {text_idx}) from {len(text_sources)} options")
|
| 166 |
-
|
| 167 |
-
return {
|
| 168 |
-
'aion_embedding': torch.from_numpy(aion_embedding),
|
| 169 |
-
'text_embedding': torch.from_numpy(selected_text),
|
| 170 |
-
'galaxy_index': galaxy_idx,
|
| 171 |
-
'text_source': selected_source,
|
| 172 |
-
'text_index': text_idx,
|
| 173 |
-
'object_id': row['object_id']
|
| 174 |
-
}
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def create_unified_multi_text_loaders(
|
| 178 |
-
unified_embeddings_path,
|
| 179 |
-
batch_size=64,
|
| 180 |
-
train_ratio=0.8,
|
| 181 |
-
pin_memory=True,
|
| 182 |
-
text_sampling_strategy="random",
|
| 183 |
-
num_workers=4,
|
| 184 |
-
max_train_samples=None,
|
| 185 |
-
num_embedding=None,
|
| 186 |
-
**kwargs
|
| 187 |
-
):
|
| 188 |
-
"""
|
| 189 |
-
Create train and validation data loaders for multi-text training from unified parquet.
|
| 190 |
-
|
| 191 |
-
Args:
|
| 192 |
-
unified_embeddings_path: Path to unified parquet file
|
| 193 |
-
batch_size: Batch size for training
|
| 194 |
-
train_ratio: Fraction of samples for training
|
| 195 |
-
pin_memory: Whether to pin memory for GPU transfer
|
| 196 |
-
text_sampling_strategy: How to sample text embeddings ("original", "summaries-only", "specific_summary", "random", "round-robin", "weighted")
|
| 197 |
-
num_workers: Number of data loading workers
|
| 198 |
-
max_train_samples: Maximum number of training samples (for data scaling experiments)
|
| 199 |
-
num_embedding: When using "specific_summary" strategy, the index of the embedding to use
|
| 200 |
-
**kwargs: Additional arguments
|
| 201 |
-
"""
|
| 202 |
-
|
| 203 |
-
# Convert to Path
|
| 204 |
-
parquet_path = Path(unified_embeddings_path)
|
| 205 |
-
|
| 206 |
-
if not parquet_path.exists():
|
| 207 |
-
raise ValueError(f"Unified embeddings file not found: {parquet_path}")
|
| 208 |
-
|
| 209 |
-
logger.info(f"Creating unified multi-text data loaders from {parquet_path}")
|
| 210 |
-
logger.info(f"Batch size: {batch_size}, Workers: {num_workers}")
|
| 211 |
-
logger.info(f"Text sampling strategy: {text_sampling_strategy}")
|
| 212 |
-
|
| 213 |
-
# Create datasets
|
| 214 |
-
train_dataset = UnifiedMultiTextDataset(
|
| 215 |
-
parquet_path=parquet_path,
|
| 216 |
-
split="train",
|
| 217 |
-
train_ratio=train_ratio,
|
| 218 |
-
text_sampling_strategy=text_sampling_strategy,
|
| 219 |
-
max_train_samples=max_train_samples,
|
| 220 |
-
num_embedding=num_embedding
|
| 221 |
-
)
|
| 222 |
-
|
| 223 |
-
val_dataset = UnifiedMultiTextDataset(
|
| 224 |
-
parquet_path=parquet_path,
|
| 225 |
-
split="val",
|
| 226 |
-
train_ratio=train_ratio,
|
| 227 |
-
text_sampling_strategy=text_sampling_strategy,
|
| 228 |
-
num_embedding=num_embedding
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
# Create loaders
|
| 232 |
-
train_loader = DataLoader(
|
| 233 |
-
train_dataset,
|
| 234 |
-
batch_size=batch_size,
|
| 235 |
-
shuffle=True, # Shuffle within the train split
|
| 236 |
-
num_workers=num_workers,
|
| 237 |
-
pin_memory=pin_memory,
|
| 238 |
-
drop_last=True # Drop incomplete batches for stable training
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
val_loader = DataLoader(
|
| 242 |
-
val_dataset,
|
| 243 |
-
batch_size=batch_size,
|
| 244 |
-
shuffle=False, # No shuffle for validation
|
| 245 |
-
num_workers=num_workers,
|
| 246 |
-
pin_memory=pin_memory,
|
| 247 |
-
drop_last=False
|
| 248 |
-
)
|
| 249 |
-
|
| 250 |
-
return train_loader, val_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip/utils/io_utils.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
I/O utilities for saving and loading CLIP embeddings.
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
-
import h5py
|
| 6 |
-
import numpy as np
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from datetime import datetime
|
| 9 |
-
import logging
|
| 10 |
-
|
| 11 |
-
logger = logging.getLogger(__name__)
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def save_clip_embeddings_hdf5(
|
| 15 |
-
object_ids,
|
| 16 |
-
galaxy_data,
|
| 17 |
-
text_data,
|
| 18 |
-
aion_clip_embeddings,
|
| 19 |
-
text_clip_embeddings,
|
| 20 |
-
output_dir="data/processed"
|
| 21 |
-
):
|
| 22 |
-
"""Save CLIP embeddings to separate HDF5 files."""
|
| 23 |
-
output_dir = Path(output_dir)
|
| 24 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
| 25 |
-
|
| 26 |
-
# File paths (standardized names)
|
| 27 |
-
aion_clip_path = output_dir / "galaxy_aion_clip_embeddings.hdf5"
|
| 28 |
-
text_clip_path = output_dir / "galaxy_text_clip_embeddings.hdf5"
|
| 29 |
-
|
| 30 |
-
logger.info(f"Saving AION CLIP embeddings to: {aion_clip_path}")
|
| 31 |
-
|
| 32 |
-
# Save AION CLIP embeddings
|
| 33 |
-
with h5py.File(aion_clip_path, 'w') as f:
|
| 34 |
-
# Object IDs
|
| 35 |
-
dt = h5py.special_dtype(vlen=str)
|
| 36 |
-
f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
|
| 37 |
-
|
| 38 |
-
# Coordinates and metadata
|
| 39 |
-
ra_values = np.array([galaxy_data[oid]['ra'] for oid in object_ids])
|
| 40 |
-
dec_values = np.array([galaxy_data[oid]['dec'] for oid in object_ids])
|
| 41 |
-
healpix_values = np.array([galaxy_data[oid]['healpix'] for oid in object_ids])
|
| 42 |
-
|
| 43 |
-
f.create_dataset('ra', data=ra_values, dtype=np.float64)
|
| 44 |
-
f.create_dataset('dec', data=dec_values, dtype=np.float64)
|
| 45 |
-
f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
|
| 46 |
-
|
| 47 |
-
# AION CLIP embeddings
|
| 48 |
-
f.create_dataset('AION_clip_embedding', data=aion_clip_embeddings, dtype=np.float32)
|
| 49 |
-
|
| 50 |
-
# Metadata
|
| 51 |
-
f.attrs['description'] = 'AION embeddings encoded through trained CLIP model'
|
| 52 |
-
f.attrs['embedding_dim'] = aion_clip_embeddings.shape[1]
|
| 53 |
-
f.attrs['num_objects'] = len(object_ids)
|
| 54 |
-
f.attrs['created'] = datetime.now().isoformat()
|
| 55 |
-
|
| 56 |
-
logger.info(f"Saving text CLIP embeddings to: {text_clip_path}")
|
| 57 |
-
|
| 58 |
-
# Save text CLIP embeddings
|
| 59 |
-
with h5py.File(text_clip_path, 'w') as f:
|
| 60 |
-
# Object IDs
|
| 61 |
-
dt = h5py.special_dtype(vlen=str)
|
| 62 |
-
f.create_dataset('object_id', data=[str(oid) for oid in object_ids], dtype=dt)
|
| 63 |
-
|
| 64 |
-
# Coordinates and metadata (use text data for consistency)
|
| 65 |
-
ra_values = np.array([text_data[oid]['ra'] for oid in object_ids])
|
| 66 |
-
dec_values = np.array([text_data[oid]['dec'] for oid in object_ids])
|
| 67 |
-
healpix_values = np.array([text_data[oid]['healpix'] for oid in object_ids])
|
| 68 |
-
|
| 69 |
-
f.create_dataset('ra', data=ra_values, dtype=np.float64)
|
| 70 |
-
f.create_dataset('dec', data=dec_values, dtype=np.float64)
|
| 71 |
-
f.create_dataset('healpix', data=healpix_values, dtype=np.int64)
|
| 72 |
-
|
| 73 |
-
# Text CLIP embeddings
|
| 74 |
-
f.create_dataset('text_clip_embedding', data=text_clip_embeddings, dtype=np.float32)
|
| 75 |
-
|
| 76 |
-
# Metadata
|
| 77 |
-
f.attrs['description'] = 'Text embeddings encoded through trained CLIP model'
|
| 78 |
-
f.attrs['embedding_dim'] = text_clip_embeddings.shape[1]
|
| 79 |
-
f.attrs['num_objects'] = len(object_ids)
|
| 80 |
-
f.attrs['created'] = datetime.now().isoformat()
|
| 81 |
-
|
| 82 |
-
return aion_clip_path, text_clip_path
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def inspect_generated_files(aion_clip_path, text_clip_path):
|
| 86 |
-
"""Inspect the generated HDF5 files."""
|
| 87 |
-
logger.info("Inspecting generated AION CLIP embeddings file...")
|
| 88 |
-
|
| 89 |
-
with h5py.File(aion_clip_path, 'r') as f:
|
| 90 |
-
logger.info(f"AION file datasets: {list(f.keys())}")
|
| 91 |
-
for key in f.keys():
|
| 92 |
-
dataset = f[key]
|
| 93 |
-
logger.info(f" {key}: shape={dataset.shape}, dtype={dataset.dtype}")
|
| 94 |
-
logger.info(f" Attributes: {dict(f.attrs)}")
|
| 95 |
-
|
| 96 |
-
logger.info("Inspecting generated text CLIP embeddings file...")
|
| 97 |
-
|
| 98 |
-
with h5py.File(text_clip_path, 'r') as f:
|
| 99 |
-
logger.info(f"Text file datasets: {list(f.keys())}")
|
| 100 |
-
for key in f.keys():
|
| 101 |
-
dataset = f[key]
|
| 102 |
-
logger.info(f" {key}: shape={dataset.shape}, dtype={dataset.dtype}")
|
| 103 |
-
logger.info(f" Attributes: {dict(f.attrs)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
clip/utils/logging_utils.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
"""Logging utilities."""
|
| 2 |
-
|
| 3 |
-
import logging
|
| 4 |
-
import sys
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
def setup_logging(log_level: str = "INFO", log_file: str = None):
|
| 9 |
-
"""
|
| 10 |
-
Setup logging configuration.
|
| 11 |
-
|
| 12 |
-
Args:
|
| 13 |
-
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
| 14 |
-
log_file: Optional path to log file
|
| 15 |
-
"""
|
| 16 |
-
# Clear any existing handlers
|
| 17 |
-
logging.getLogger().handlers.clear()
|
| 18 |
-
|
| 19 |
-
# Create formatter
|
| 20 |
-
formatter = logging.Formatter(
|
| 21 |
-
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 22 |
-
)
|
| 23 |
-
|
| 24 |
-
# Console handler
|
| 25 |
-
console_handler = logging.StreamHandler(sys.stdout)
|
| 26 |
-
console_handler.setFormatter(formatter)
|
| 27 |
-
|
| 28 |
-
# Setup root logger
|
| 29 |
-
logger = logging.getLogger()
|
| 30 |
-
logger.setLevel(getattr(logging, log_level.upper()))
|
| 31 |
-
logger.addHandler(console_handler)
|
| 32 |
-
|
| 33 |
-
# File handler if specified
|
| 34 |
-
if log_file:
|
| 35 |
-
log_path = Path(log_file)
|
| 36 |
-
log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 37 |
-
|
| 38 |
-
file_handler = logging.FileHandler(log_path)
|
| 39 |
-
file_handler.setFormatter(formatter)
|
| 40 |
-
logger.addHandler(file_handler)
|
| 41 |
-
|
| 42 |
-
return logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
def main():
|
| 2 |
-
print("Hello from aion-search!")
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
if __name__ == "__main__":
|
| 6 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|