| """ |
| JetFormer Embedding Extractor + Reconstruction Sampler |
| |
| - Imports model definitions from 'train_jetformer_sogol.py' |
| - Generates a 12-image reconstruction panel (Original vs Model Prediction). |
| - Extracts Transformer embeddings (h) from 'test' and 'validation' splits. |
| - Concatenates both splits into single output files (zss...npy and idxs...npy). |
| """ |
|
|
| import os |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| import matplotlib.pyplot as plt |
| from tqdm import tqdm |
| from datasets import load_dataset |
| from PIL import ImageFile |
|
|
| |
| |
| from train_jetformer_sogol import JetFormer, CFG, uniform_dequantize, patchify, depatchify |
|
|
| |
# Ask Pillow to decode image files whose data ends early instead of raising
# while the dataset is being iterated.
setattr(ImageFile, "LOAD_TRUNCATED_IMAGES", True)
|
|
| |
| |
| |
| |
# --- Paths -----------------------------------------------------------------
# Defaults point at the original author's machine; override them with the
# JETFORMER_OUT_DIR / JETFORMER_CHECKPOINT environment variables.
out_dir = os.environ.get("JETFORMER_OUT_DIR", "/mnt/c/Users/shaha/Downloads/sogol")
checkpoint_path = os.environ.get(
    "JETFORMER_CHECKPOINT",
    "/mnt/c/Users/shaha/Downloads/sogol/sogol_checkpoint_step_0079999.pt",
)

# --- Runtime ---------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128
num_workers = 0  # 0 = DataLoader loads batches in the main process

# --- Dataset ---------------------------------------------------------------
dataset_name = "Smith42/galaxies"
stream_hf_dataset = True
id_field_name = "dr8_id"  # record field saved (as str) alongside each embedding

# --- Embeddings ------------------------------------------------------------
splits_for_embeddings = ["test", "validation"]
# "mean": average over tokens, "last": final token, anything else: keep all tokens.
embed_reduction = "mean"
# Number of transformer blocks applied before reading the hidden state h.
embed_layer_index = 12
|
|
|
|
| |
| |
| |
|
|
def to_float_image(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a float image in [0, 1] without adding any noise.

    uint8 tensors are converted to float and divided by 255; any other tensor
    is clamped to [0, 1]. No randomness is involved, so embeddings computed
    from the result are deterministic.
    """
    is_byte_image = x.dtype == torch.uint8
    if not is_byte_image:
        return x.clamp(0.0, 1.0)
    return x.float() / 255.0
|
|
def load_model(path, device):
    """Build a JetFormer from the default ``CFG`` and load weights from ``path``.

    Args:
        path: Checkpoint file readable by ``torch.load``.
        device: Device the model is created on and checkpoint tensors are
            mapped to.

    Returns:
        ``(model, cfg)``, with the model already switched to eval mode.

    Weights are taken from the ``'model_state_dict'`` or ``'model'`` entry of
    the checkpoint. If neither key is present, the loaded object itself is
    treated as the state dict. Wrapper path components added by
    ``torch.compile`` (``_orig_mod``) and DataParallel/DDP (``module``) are
    removed from every key before loading.
    """
    print(f"Loading checkpoint: {path}")
    cfg = CFG()
    model = JetFormer(cfg).to(device)

    ckpt = torch.load(path, map_location=device)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    elif isinstance(ckpt, dict) and 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        # Bare state dict saved directly with torch.save(model.state_dict()).
        state_dict = ckpt

    wrapper_parts = ("_orig_mod", "module")
    new_state_dict = {}
    for key, value in state_dict.items():
        # Drop whole dotted components only: a substring replace would also
        # mangle keys that merely contain "module." (e.g. "submodule.weight").
        parts = [part for part in key.split(".") if part not in wrapper_parts]
        new_state_dict[".".join(parts)] = value

    model.load_state_dict(new_state_dict)
    model.eval()
    return model, cfg
|
|
def get_transformer_embeddings(model, x: torch.Tensor, layer_index: int = None):
    """Return the transformer hidden state ``h`` after ``layer_index`` blocks.

    Pipeline: ``to_float_image(x)`` -> ``patchify`` -> ``model.flow`` (forward)
    -> ``model.in_proj`` + ``model.pos`` -> the first ``layer_index`` blocks of
    ``model.gpt.h``. No final layer norm is applied.

    Args:
        model: JetFormer-like model exposing ``cfg.patch``, ``flow``,
            ``in_proj``, ``pos`` and an indexable block sequence ``gpt.h``.
        x: Image batch. uint8 input is scaled to [0, 1]; float input is clamped.
        layer_index: Number of transformer blocks to run. ``None`` falls back
            to the module-level ``embed_layer_index``. ``0`` returns the
            projected flow latents plus positional embedding.

    Returns:
        The hidden state produced by the last block applied.

    Raises:
        ValueError: If ``layer_index`` is negative.

    For stable embeddings, no dequantization noise is added and the model is
    put in eval mode (dropout off). Gradients are NOT disabled here; the
    script's caller wraps this call in ``torch.no_grad()``.
    """
    if layer_index is None:
        layer_index = embed_layer_index
    if layer_index < 0:
        # range(negative) would run zero blocks and silently return the
        # un-transformed projection.
        raise ValueError(f"layer_index must be >= 0, got {layer_index}")

    model.eval()

    # Plain [0, 1] float conversion keeps the embeddings deterministic.
    # NOTE(review): generate_reconstruction feeds uniform_dequantize(x) to the
    # flow instead; confirm both paths match the input scale used in training.
    x_in = to_float_image(x)
    tokens_in = patchify(x_in, model.cfg.patch)

    z, _ = model.flow(tokens_in, reverse=False)

    h = model.in_proj(z) + model.pos

    blocks = model.gpt.h
    for i in range(layer_index):
        h = blocks[i](h)
    return h
|
|
def generate_reconstruction(model, x_real_batch: torch.Tensor):
    """Reconstruct images from one-step-ahead latent predictions.

    Steps:
      1. ``uniform_dequantize`` the batch, patchify it and push it forward
         through the flow to obtain latents ``z``.
      2. Run the full transformer on ``in_proj(z) + pos``. For every position,
         take the mean of the mixture component with the largest logit from
         ``model.head`` as the predicted latent.
      3. Keep the true latent at position 0. Fill position t (t >= 1) with the
         prediction made at position t - 1, i.e. the head output at t is
         treated as a prediction of token t + 1.
      4. Invert the flow and depatchify back to image space.

    Returns:
        ``(x_real_batch, x_rec)``: the unmodified input batch and the
        reconstruction clamped to [0, 1].

    Gradients are not disabled here; the caller is expected to use
    ``torch.no_grad()``. Shapes of the head outputs are presumably
    logits (B, N, K) and means (B, N, K, d_token) — TODO confirm against the
    model definition.
    """
    model.eval()
    cfg = model.cfg

    # 1. Image -> flow latents.
    dequantized = uniform_dequantize(x_real_batch)
    z_real, _ = model.flow(patchify(dequantized, cfg.patch), reverse=False)

    # 2. Latents -> per-position mixture parameters.
    hidden = model.gpt(model.in_proj(z_real) + model.pos)
    mixture_logits, mixture_means, _ = model.head(hidden)

    # Index of the most probable component, shaped to gather along dim 2.
    top_component = mixture_logits.argmax(dim=-1)[..., None, None]
    top_component = top_component.expand(-1, -1, -1, cfg.d_token)
    z_next = mixture_means.gather(2, top_component).squeeze(2)

    # 3. Shift predictions right by one; position 0 keeps the real latent.
    z_rec = torch.zeros_like(z_real)
    z_rec[:, 0] = z_real[:, 0]
    z_rec[:, 1:] = z_next[:, :-1]

    # 4. Latents -> image.
    rec_tokens, _ = model.flow(z_rec, reverse=True)
    x_rec = depatchify(rec_tokens, cfg.in_ch, cfg.img_size, cfg.img_size, cfg.patch)
    return x_real_batch, x_rec.clamp(0, 1)
|
|
def process_hf_item(item):
    """Convert one Hugging Face record into ``{"img": tensor, "id": str}``.

    ``item['image_crop']`` is converted with ``transforms.ToTensor()`` (a PIL
    image becomes a channels-first tensor) and then coerced to 3 channels:
      - a trailing alpha channel (2-channel LA or 4-channel RGBA) is dropped;
      - a single-channel result is repeated to 3 channels.

    ``id`` is ``str(item[id_field_name])``, or ``"-1"`` when the field is absent.
    """
    img_t = transforms.ToTensor()(item['image_crop'])
    if img_t.shape[0] in (2, 4):
        # Drop alpha so every sample has the same channel count; otherwise
        # batching (torch.stack / default_collate) and the model input break.
        img_t = img_t[:-1]
    if img_t.shape[0] == 1:
        img_t = img_t.repeat(3, 1, 1)

    img_id = str(item.get(id_field_name, "-1"))
    return {"img": img_t, "id": img_id}
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    os.makedirs(out_dir, exist_ok=True)

    model, cfg = load_model(checkpoint_path, device)

    # ------------------------------------------------------------------
    # Step 1: panel of 12 original images next to their reconstructions.
    # ------------------------------------------------------------------
    print("\n[1/2] Creating 12-image reconstruction panel...")

    recon_stream = load_dataset(dataset_name, split="test", streaming=stream_hf_dataset)
    recon_stream = recon_stream.map(
        process_hf_item,
        remove_columns=["image", "image_crop", "survey", "ra", "dec"],
    )

    # First 12 records of the test split (StopIteration if there are fewer).
    sample_iter = iter(recon_stream)
    panel_batch = torch.stack([next(sample_iter)['img'] for _ in range(12)]).to(device)

    with torch.no_grad():
        origs, recons = generate_reconstruction(model, panel_batch)

    # NCHW tensors -> NHWC arrays for imshow.
    origs = origs.cpu().permute(0, 2, 3, 1).numpy()
    recons = recons.cpu().permute(0, 2, 3, 1).numpy()

    fig, axs = plt.subplots(12, 2, figsize=(6, 36), constrained_layout=True)
    for row, (left_ax, right_ax) in enumerate(axs):
        left_ax.imshow(np.clip(origs[row], 0, 1))
        left_ax.axis("off")
        left_ax.set_title("Original" if row == 0 else "")
        right_ax.imshow(np.clip(recons[row], 0, 1))
        right_ax.axis("off")
        right_ax.set_title("Reconstructed" if row == 0 else "")

    recon_path = os.path.join(out_dir, "reconstruction_panel.png")
    fig.savefig(recon_path, dpi=150)
    plt.close(fig)
    print(f"Saved recon panel: {recon_path}")

    # ------------------------------------------------------------------
    # Step 2: embeddings for every configured split, saved as one array.
    # ------------------------------------------------------------------
    print("\n[2/2] Extracting Embeddings...")

    # Token reduction applied to each (batch, tokens, dim) hidden state;
    # unknown settings keep the full token sequence.
    reducers = {
        "mean": lambda hidden: hidden.mean(dim=1),
        "last": lambda hidden: hidden[:, -1, :],
    }
    reduce_tokens = reducers.get(embed_reduction, lambda hidden: hidden)

    all_embeddings = []
    all_ids = []

    for split in splits_for_embeddings:
        print(f" -> Processing split: {split}")
        split_ds = load_dataset(dataset_name, split=split, streaming=stream_hf_dataset)
        split_ds = split_ds.map(
            process_hf_item,
            remove_columns=["image", "image_crop", "survey", "ra", "dec"],
        )

        loader = DataLoader(split_ds, batch_size=batch_size, num_workers=num_workers)
        progress = tqdm(loader, desc=f"Extracting {split}")

        with torch.no_grad():
            for batch in progress:
                hidden_state = get_transformer_embeddings(model, batch['img'].to(device))
                emb = reduce_tokens(hidden_state)
                all_embeddings.append(emb.float().cpu().numpy())
                all_ids.extend(batch['id'])

    print("\nConcatenating all splits...")
    final_zss = np.concatenate(all_embeddings, axis=0)
    final_ids = np.array(all_ids)

    zss_path = os.path.join(out_dir, f"zss_combined_layer{embed_layer_index}_deterministic_{embed_reduction}.npy")
    ids_path = os.path.join(out_dir, "idxs_combined.npy")

    for target_path, arr in ((zss_path, final_zss), (ids_path, final_ids)):
        np.save(target_path, arr)

    print(f"Saved Embeddings: {final_zss.shape}")
    print(f" -> {zss_path}")
    print(f"Saved IDs: {final_ids.shape}")
    print(f" -> {ids_path}")
    print("Done.")