"""
Sample embeddings + reconstructions from a trained AstroPT model (AIM-compatible).

DETERMINISTIC / NO-SHUFFLE VERSION:
  - Pulls from HuggingFace "Smith42/galaxies" using streaming=True.
  - DOES NOT shuffle at all.
  - Iterates in the exact order HuggingFace yields examples (shard/file order).
  - Generates `n_recon` (default 12) reconstructions FIRST, using the first
    `n_recon` examples of `recon_split`.
  - Then extracts embeddings from BOTH `test` + `validation`, in order, and
    saves them as .npy.
  - Saves IDs from the HF column `dr8_id` in a separate file:
      idxs_...npy (dtype=object holding strings), aligned 1:1 with the
      embedding rows. Load it with np.load(..., allow_pickle=True).
  - Avoids concatenate_datasets() for streaming datasets (prevents a PyArrow
    crash with torch tensors).

Notes:
  - If you want the strictest "exact order" behaviour, set num_workers = 0.
  - The recon figure uses validate()-style teacher forcing and rendering: the
    model sees tokens[:-1] with targets tokens[1:], a zero token is prepended
    to restore the full patch grid, and the sequence is antispiralised when
    the model was trained with spiral patch ordering.
"""

import os
import math
import functools
import itertools
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import einops
from tqdm import tqdm
from datasets import load_dataset

from astropt.model import GPT, GPTConfig
from astropt.local_datasets import GalaxyImageDataset

# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------
# checkpoint/log dir
out_dir = "/mnt/c/Users/shaha/Downloads/AIM"  # where ckpt.pt lives + where outputs will be written

# compute
device = "cuda"
dtype = "bfloat16"  # "float32" | "bfloat16" | "float16" (autocast dtype on CUDA)
compile = False  # name kept from the training config; note it shadows the builtin

# HF dataset
dataset_name = "Smith42/galaxies"
stream_hf_dataset = True  # keep streaming
revision = None  # set e.g. "v2.0" if you need a pinned revision

# Embedding extraction
splits_for_embeddings = ["test", "validation"]  # as requested
batch_size = 256
num_workers = 0  # set to 0 for strictest order guarantee
pin_memory = True  # only applied when `device` is a CUDA device
prefix_len = 64  # number of image-tokens used for embeddings (clamped to block_size)
embed_reduction = "mean"  # "mean" | "last" | "exp_decay" | "none"

# Reconstruction figure
n_recon = 12
recon_split = "test"
save_recon_name = "recon_12.png"

# If your HF dataset uses 'dr8_id' (string), we save it. If missing, we store "-1".
id_field_name = "dr8_id"


# -----------------------------------------------------------------------------
# Transforms (match training)
# -----------------------------------------------------------------------------
def normalise(x: torch.Tensor, use_hf: bool = False) -> torch.Tensor:
    """Standardise ``x`` along dim 1 and return it as float16 (matches training).

    Args:
        x: Input tensor (presumably the patch tokens produced by
            GalaxyImageDataset -- confirm against the training pipeline).
            When ``use_hf`` is True a NumPy array is also accepted.
        use_hf: If True, NumPy input is converted to a float32 tensor first,
            because the HF dataset yields NumPy arrays.

    Returns:
        ``(x - mean) / (std + 1e-8)`` computed over dim 1, cast to float16.
    """
    if use_hf and isinstance(x, np.ndarray):
        x = torch.from_numpy(x).to(torch.float32)
    std, mean = torch.std_mean(x, dim=1, keepdim=True)
    x_norm = (x - mean) / (std + 1e-8)  # epsilon guards against zero std
    return x_norm.to(torch.float16)


def data_transforms(use_hf: bool):
    """Return the image transform pipeline used in training (just ``normalise``)."""
    return transforms.Compose(
        [
            transforms.Lambda(functools.partial(normalise, use_hf=use_hf)),
        ]
    )


def process_galaxy_wrapper(galdict: Dict[str, Any], func):
    """Wrapper for processing galaxy images from HF dataset (MATCH TRAINING).

    ``galdict["image"]`` is converted to a NumPy array with axes 0 and 2
    swapped and passed to ``func`` (GalaxyImageDataset.process_galaxy).

    Returns:
        Dict with ``"images"`` (float32 token tensor), ``"images_positions"``
        (0..len-1 as int64) and ``id_field_name`` (the galaxy ID, or "-1"
        when the example has no such field).
    """
    patch_galaxy = func(np.array(galdict["image"]).swapaxes(0, 2))
    return {
        "images": patch_galaxy.to(torch.float),
        "images_positions": torch.arange(0, len(patch_galaxy), dtype=torch.long),
        # keep ID as string if present
        id_field_name: galdict.get(id_field_name, "-1"),
    }


# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
def load_model(out_dir: str, device: str, dtype: str, compile_model: bool):
    """Load the AstroPT GPT checkpoint stored at ``out_dir/ckpt.pt``.

    Args:
        out_dir: Directory containing ``ckpt.pt``.
        device: Torch device string; any string containing "cuda" enables
            autocast.
        dtype: Autocast dtype name: "float32", "bfloat16" or "float16".
        compile_model: Wrap the model with ``torch.compile`` if True.

    Returns:
        ``(model, modality_registry, checkpoint, ctx)`` where ``model`` is in
        eval mode on ``device`` and ``ctx`` is an autocast context on CUDA or
        a no-op ``nullcontext`` on CPU.

    Raises:
        FileNotFoundError: If ``ckpt.pt`` does not exist.
        KeyError: If ``dtype`` is not one of the supported names.
    """
    ckpt_path = os.path.join(out_dir, "ckpt.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # weights_only=False: the checkpoint stores a modality-registry object
    # (not just tensors), so full unpickling is needed. Only load trusted files.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    modality_registry = checkpoint["modality_registry"]

    gptconf = GPTConfig(**checkpoint["model_args"])
    model = GPT(gptconf, modality_registry)

    # Strip the "_orig_mod." prefix that torch.compile adds to parameter names,
    # so a checkpoint saved from a compiled model loads into a plain one.
    state_dict = checkpoint["model"]
    unwanted_prefix = "_orig_mod."
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict)
    model.eval().to(device)
    if compile_model:
        model = torch.compile(model)

    device_type = "cuda" if "cuda" in device else "cpu"
    ptdtype = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
    }[dtype]
    ctx = (
        nullcontext()
        if device_type == "cpu"
        else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
    )
    return model, modality_registry, checkpoint, ctx


# -----------------------------------------------------------------------------
# HF dataset builder (match training pipeline)
# -----------------------------------------------------------------------------
def build_hf_stream(
    dataset_name: str,
    split: str,
    galproc: GalaxyImageDataset,
    streaming: bool,
    revision: Optional[str] = None,
):
    """Load one HF split and map it through the training-time galaxy processing.

    Keeps only ``image_crop`` (renamed to ``image``) and, if present,
    ``id_field_name``. Each example is converted by ``process_galaxy_wrapper``,
    and the raw ``image`` column is then removed so batches only carry
    ``images``, ``images_positions`` and the ID.

    No shuffling is applied: examples come out in the order HF yields them.
    """
    kwargs = dict(split=split, streaming=streaming)
    if revision is not None:
        kwargs["revision"] = revision

    ds = load_dataset(dataset_name, **kwargs)

    # Keep id_field_name alongside image_crop when the dataset has it.
    # select_columns() raises ValueError for a column missing from the known
    # features, so fall back to image_crop alone. Other errors propagate.
    # NOTE(review): if a streaming dataset's features are unknown, the missing
    # column may only surface during iteration.
    try:
        ds = ds.select_columns(["image_crop", id_field_name])
    except ValueError:
        ds = ds.select_columns(["image_crop"])

    ds = ds.rename_column("image_crop", "image")
    ds = ds.map(functools.partial(process_galaxy_wrapper, func=galproc.process_galaxy))

    # process_galaxy_wrapper returns "images", "images_positions" and
    # id_field_name; drop the raw "image" to keep batches light.
    ds = ds.remove_columns("image")
    return ds


# -----------------------------------------------------------------------------
# validate()-style token->image conversion (shift + spiral handling)
# -----------------------------------------------------------------------------
def tokens_to_images_validate_style(
    tokens: torch.Tensor,  # [B, T, patch_dim]
    galproc: GalaxyImageDataset,
    spiral: bool,
    patch_size: int,
    image_size: int,
) -> np.ndarray:
    """Convert token sequences to a float32 ``[B, H, W, C]`` NumPy array, like validate().

    Two input lengths are accepted, with h = w = image_size // patch_size:

    - ``T == h*w - 1``: teacher-forced targets (tokens[1:]) or the predictions
      made from tokens[:-1]. As in validate(), a zero token is prepended at
      t=0 to restore the full h*w grid.
    - ``T == h*w``: an already complete sequence, used as-is.

    If ``spiral`` is True, each full sequence is then antispiralised back to
    raster order. Only after that are the patches rearranged into an image.
    No token is ever dropped after reordering.

    Raises:
        RuntimeError: If patch_dim is not patch_size**2 * n_chan, or if T
            matches neither accepted length.
    """
    B, T, D = tokens.shape

    n_chan = D // (patch_size * patch_size)
    if patch_size * patch_size * n_chan != D:
        raise RuntimeError(
            f"Cannot factor patch_dim={D} into patch_size^2 * n_chan (patch_size={patch_size})."
        )

    h = image_size // patch_size
    w = image_size // patch_size
    n_patches = h * w

    if T == n_patches - 1:
        # validate(): the first token has no prediction, so fill it with zeros.
        zero_block = torch.zeros((B, 1, D), device=tokens.device, dtype=tokens.dtype)
        tokens = torch.cat((zero_block, tokens), dim=1)  # [B, h*w, D]
    elif T != n_patches:
        raise RuntimeError(
            f"Token count {T} is neither h*w ({h}*{w}={n_patches}) nor h*w - 1. "
            f"image_size={image_size}, patch_size={patch_size}."
        )

    if spiral:
        # Antispiralise the complete h*w sequence (zero token included),
        # exactly as validate() does.
        tokens = torch.stack([galproc.antispiralise(seq) for seq in tokens])

    img = einops.rearrange(
        tokens,
        "b (hh ww) (p1 p2 c) -> b (hh p1) (ww p2) c",
        p1=patch_size,
        p2=patch_size,
        hh=h,
        ww=w,
        c=n_chan,
    )
    return img.to(torch.float32).detach().cpu().numpy()


# -----------------------------------------------------------------------------
# Pipeline steps
# -----------------------------------------------------------------------------
def _save_recon_panel(
    model, ctx, galproc: GalaxyImageDataset, spiral: bool, patch_size: int
) -> str:
    """Plot the first ``n_recon`` galaxies of ``recon_split`` beside their reconstructions.

    Returns the path of the saved figure.
    """
    if n_recon < 1:
        raise ValueError(f"n_recon must be >= 1, got {n_recon}.")

    ds_recon = build_hf_stream(
        dataset_name=dataset_name,
        split=recon_split,
        galproc=galproc,
        streaming=stream_hf_dataset,
        revision=revision,
    )

    # Pull the first n_recon samples in HF order (no shuffle).
    samples: List[Dict[str, Any]] = list(itertools.islice(ds_recon, n_recon))
    if len(samples) < n_recon:
        raise RuntimeError(
            f"Split {recon_split!r} yielded only {len(samples)} examples; "
            f"n_recon={n_recon} were requested."
        )

    # Stack into a batch: images [B, T, D], positions [B, T].
    images = torch.stack([s["images"] for s in samples], dim=0).to(device)
    positions = torch.stack([s["images_positions"] for s in samples], dim=0).to(device)

    # Infer image_size from the token length T = (H / patch) * (W / patch).
    n_tokens = images.size(1)
    side = math.isqrt(n_tokens)
    if side * side != n_tokens:
        raise RuntimeError(
            f"Token length T={n_tokens} is not a perfect square; cannot infer image_size cleanly."
        )
    image_size = side * patch_size

    # Teacher forcing as in training/validate(). In an autoregressive model,
    # the output at position t predicts token t + 1, so inputs are tokens[:-1]
    # and targets are tokens[1:]. Running on the unshifted sequence would draw
    # every predicted patch one position too early.
    X = {"images": images[:, :-1], "images_positions": positions[:, :-1]}
    Y = {"images": images[:, 1:], "images_positions": positions[:, 1:]}
    with torch.no_grad(), ctx:
        P, _loss = model(X, targets=Y)

    render = functools.partial(
        tokens_to_images_validate_style,
        galproc=galproc,
        spiral=spiral,
        patch_size=patch_size,
        image_size=image_size,
    )
    Y_img = render(Y["images"])
    P_img = render(P["images"])

    # squeeze=False keeps axs 2-D even when n_recon == 1.
    fig, axs = plt.subplots(
        n_recon, 2, figsize=(6, 3 * n_recon), constrained_layout=True, squeeze=False
    )
    try:
        for row, (orig_img, recon_img) in enumerate(zip(Y_img, P_img)):
            for ax, img, title in (
                (axs[row, 0], orig_img, "Original"),
                (axs[row, 1], recon_img, "Reconstructed"),
            ):
                ax.imshow(np.clip(img, 0, 1))
                ax.axis("off")
                ax.set_title(title)

        recon_path = os.path.join(out_dir, save_recon_name)
        fig.savefig(recon_path, dpi=150)
    finally:
        plt.close(fig)
    return recon_path


def _extract_embeddings(
    model, ctx, galproc: GalaxyImageDataset, n_prefix: int
) -> Tuple[np.ndarray, np.ndarray]:
    """Embed every example of ``splits_for_embeddings``, in HF order.

    Each split is streamed and embedded separately (no concatenate_datasets).
    Only the first ``n_prefix`` tokens of each galaxy are fed to the model.

    Returns:
        ``(zss, ids)``: embeddings concatenated along axis 0, and an object
        array of IDs with exactly one entry per embedding row.

    Raises:
        RuntimeError: If no examples were produced, or if the ID and
            embedding row counts disagree.
    """
    zss_chunks: List[np.ndarray] = []
    ids_chunks: List[np.ndarray] = []

    with torch.no_grad(), ctx:
        for split in splits_for_embeddings:
            print(f"  -> split: {split}")
            ds_embed = build_hf_stream(
                dataset_name=dataset_name,
                split=split,
                galproc=galproc,
                streaming=stream_hf_dataset,
                revision=revision,
            )
            dl = DataLoader(
                ds_embed,
                batch_size=batch_size,
                num_workers=num_workers,
                # Pinned memory only benefits host-to-GPU copies.
                pin_memory=pin_memory and "cuda" in device,
            )

            with tqdm(total=None, unit="galaxies", unit_scale=True) as pbar:
                for batch in dl:
                    xs = batch["images"][:, :n_prefix].to(device)
                    pos = batch["images_positions"][:, :n_prefix].to(device)
                    inputs = {"images": xs, "images_positions": pos}

                    zs = model.generate_embeddings(inputs, reduction=embed_reduction)
                    # Upcast before .numpy(): NumPy has no bfloat16, which
                    # autocast may produce.
                    zss_chunks.append(zs["images"].detach().to(torch.float32).cpu().numpy())

                    n_batch = xs.size(0)
                    # Save IDs (strings) aligned with embeddings.
                    if id_field_name in batch:
                        # can be list[str] or an array depending on HF formatting
                        ids = np.array(batch[id_field_name], dtype=object)
                    else:
                        ids = np.array(["-1"] * n_batch, dtype=object)
                    ids_chunks.append(ids)
                    pbar.update(n_batch)

    if not zss_chunks:
        raise RuntimeError(f"No examples were yielded for splits {splits_for_embeddings}.")

    zss = np.concatenate(zss_chunks, axis=0)
    ids = np.concatenate(ids_chunks, axis=0)
    if zss.shape[0] != ids.shape[0]:
        raise RuntimeError(
            f"ID count {ids.shape[0]} does not match embedding rows {zss.shape[0]}."
        )
    return zss, ids


# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
def main() -> None:
    """Write the reconstruction panel first, then the embeddings and ID arrays."""
    os.makedirs(out_dir, exist_ok=True)

    model, modality_registry, checkpoint, ctx = load_model(out_dir, device, dtype, compile)

    # Match the training config when the checkpoint stores one.
    train_cfg = (checkpoint.get("config") or {}) if isinstance(checkpoint, dict) else {}
    spiral = bool(train_cfg.get("spiral", True))
    block_size = int(
        train_cfg.get("block_size", checkpoint["model_args"].get("block_size", 1024))
    )
    # Clamp the embedding prefix to block_size for safety.
    n_prefix = int(min(prefix_len, block_size))

    # Build the GalaxyImageDataset processor (same tokenization/spiral ops as training).
    transforms_map = {"images": data_transforms(use_hf=True)}
    galproc = GalaxyImageDataset(
        paths=None,
        spiral=spiral,
        transform=transforms_map,
        modality_registry=modality_registry,
    )

    img_cfg = modality_registry.get_config("images")
    patch_size = int(img_cfg.patch_size)

    # [1/2] Reconstructions FIRST (no shuffle => first n_recon items in HF order).
    print(f"\n[1/2] Creating {n_recon}-image reconstruction panel (orig left, recon right)...")
    recon_path = _save_recon_panel(model, ctx, galproc, spiral, patch_size)
    print(f"Saved recon panel: {recon_path}")

    # [2/2] Embedding extraction SECOND (no shuffle, no concat for streaming).
    split_names = " + ".join(splits_for_embeddings)
    print(f"\n[2/2] Extracting embeddings ({split_names}) and saving .npy...")
    zss, ids = _extract_embeddings(model, ctx, galproc, n_prefix)

    emb_path = os.path.join(out_dir, f"zss_{n_prefix}t_{embed_reduction}.npy")
    ids_path = os.path.join(out_dir, f"idxs_{n_prefix}t_{embed_reduction}.npy")

    np.save(emb_path, zss)
    # Object-dtype array: read back with np.load(ids_path, allow_pickle=True).
    np.save(ids_path, ids)

    print(f"Saved embeddings: {zss.shape}")
    print(f"  - {emb_path}")
    print(f"Saved ids: {ids.shape} (dtype={ids.dtype})")
    print(f"  - {ids_path}")
    print("Done.")


if __name__ == "__main__":
    main()