Spaces:
Sleeping
Sleeping
| # src/evaluation/eval_tsne_umap.py | |
| import argparse | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms as T, models | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from sklearn.manifold import TSNE | |
| # Reuse your test dataset loader from eval_accuracy | |
| from src.evaluation.eval_accuracy import load_test_dataset | |
| # Optional UMAP support | |
# umap-learn is optional: remember whether the import worked so the UMAP
# plots can be skipped later instead of failing at import time.
try:
    import umap
except ImportError:
    HAS_UMAP = False
    print("[INFO] umap-learn not installed; will skip UMAP and only run t-SNE.")
else:
    HAS_UMAP = True
class ResNetFeatureExtractor(nn.Module):
    """
    Wraps a torchvision ResNet18 pretrained on ImageNet and
    exposes a 512-d feature vector for each image.

    The final fully connected layer is dropped; every other child of the
    backbone is kept, and the output of the last remaining child is
    flattened into the feature vector.

    Args:
        device: torch device string that the backbone and every input
            tensor are moved to.
    """

    def __init__(self, device="cuda"):
        super().__init__()
        # Use the modern weights API
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Remove the final FC layer: keep everything up to avgpool
        self.feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
        self.feature_extractor.to(device)
        self.feature_extractor.eval()
        self.device = device
        # Standard ImageNet normalization
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def forward(self, pil_img):
        """
        Compute the feature vector of one image.

        pil_img: a single PIL.Image
        returns: numpy array of shape (512,)
        """
        x = self.transform(pil_img).unsqueeze(0).to(self.device)  # (1, 3, 224, 224)
        # eval() only switches layer behaviour; autograd is still active.
        # Without no_grad() the output requires grad and .numpy() raises
        # "Can't call numpy() on Tensor that requires grad".
        with torch.no_grad():
            feat = self.feature_extractor(x)  # (1, 512, 1, 1)
        # Batch size is always 1 here, so flattening everything yields (512,).
        return feat.reshape(-1).cpu().numpy()
def extract_features(data_root: str, max_samples: int = 2000, seed: int = 42):
    """
    Build two feature matrices (and the labels) from the test dataset.

    For every selected sample two representations are computed:
      - the image resized to 64x64, converted to grayscale and flattened
      - the 512-d vector produced by ResNetFeatureExtractor

    When max_samples is not None and smaller than the dataset, a random
    subset of that size is drawn without replacement using `seed`, and
    visited in ascending index order. Otherwise every sample is used.

    Returns:
        X_raw   : (N, 4096)
        X_resnet: (N, 512)
        y       : (N,)
    """
    print(f"[INFO] Loading test dataset from {data_root}")
    dataset = load_test_dataset(data_root)
    total = len(dataset)

    # Decide which samples to visit.
    rng = np.random.default_rng(seed)
    use_everything = max_samples is None or max_samples >= total
    if use_everything:
        indices = list(range(total))
        print(f"[INFO] Using all {total} test samples for visualization.")
    else:
        drawn = rng.choice(total, size=max_samples, replace=False)
        indices = sorted(drawn.tolist())
        print(f"[INFO] Subsampling {len(indices)} / {total} test samples for visualization.")

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"[INFO] Using device: {device}")

    # Raw pipeline: resize -> single-channel grayscale -> tensor in [0, 1].
    to_raw = T.Compose([
        T.Resize((64, 64)),
        T.Grayscale(num_output_channels=1),
        T.ToTensor(),
    ])
    resnet_extractor = ResNetFeatureExtractor(device=device)

    raw_rows, deep_rows, targets = [], [], []
    for sample_idx in tqdm(indices, desc="Extracting features"):
        # Each dataset item is expected to be a (PIL.Image, label) pair.
        image, label = dataset[sample_idx]
        targets.append(int(label))
        raw_rows.append(to_raw(image).view(-1).numpy())  # (4096,)
        deep_rows.append(resnet_extractor(image))  # (512,)

    X_raw = np.stack(raw_rows, axis=0)
    X_resnet = np.stack(deep_rows, axis=0)
    y = np.array(targets, dtype=int)

    print(f"[INFO] X_raw shape: {X_raw.shape}")
    print(f"[INFO] X_resnet shape: {X_resnet.shape}")
    print(f"[INFO] y shape: {y.shape}")
    return X_raw, X_resnet, y
def run_tsne(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Run t-SNE on feature matrix X and save a 2D scatter plot.

    Points are colored by class label. The legend lists at most
    `num_classes_to_label` classes (the first ones in sorted order) to
    avoid clutter.

    Args:
        X: feature matrix with one row per sample.
        y: class label for each row of X.
        out_path: file the figure is written to.
        title: plot title, also used in the log messages.
        num_classes_to_label: maximum number of classes shown in the legend.

    Raises:
        ValueError: if X contains fewer than 2 samples.
    """
    print(f"[INFO] Running t-SNE for {title} with shape {X.shape}")
    n_samples = X.shape[0]
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")

    tsne = TSNE(
        n_components=2,
        # scikit-learn rejects perplexity >= n_samples, so the usual value
        # of 30 is lowered when only a few samples are available.
        perplexity=min(30, n_samples - 1),
        learning_rate="auto",
        init="pca",
        random_state=42,
    )
    X_2d = tsne.fit_transform(X)

    # Plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0],
        X_2d[:, 1],
        c=y,
        s=8,
        alpha=0.7,
        cmap="tab20",
    )
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    # Legend for a subset of classes only; slicing already returns every
    # class when there are fewer than num_classes_to_label of them.
    chosen = np.unique(y)[:num_classes_to_label]
    # Proxy artists reuse the exact color the scatter assigned to each class.
    handles = [
        plt.Line2D([], [], marker="o", linestyle="",
                   color=scatter.cmap(scatter.norm(cls)))
        for cls in chosen
    ]
    labels = [f"Class {cls}" for cls in chosen]
    plt.legend(handles, labels, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved t-SNE plot to {out_path}")
def run_umap(X, y, out_path: Path, title: str, num_classes_to_label: int = 10):
    """
    Project X to 2D with UMAP and write the scatter plot to out_path.

    Points are colored by their label in y, and the legend shows at most
    `num_classes_to_label` classes. If umap-learn could not be imported,
    a warning is printed and nothing is plotted.
    """
    # Guard clause: the optional dependency is missing.
    if not HAS_UMAP:
        print(f"[WARN] UMAP not available; skipping {title}")
        return

    print(f"[INFO] Running UMAP for {title} with shape {X.shape}")
    mapper = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
    embedding = mapper.fit_transform(X)
    xs, ys = embedding[:, 0], embedding[:, 1]

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(xs, ys, c=y, s=8, alpha=0.7, cmap="tab20")
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

    # Slicing keeps every class when there are fewer than the limit.
    legend_classes = np.unique(y)[:num_classes_to_label]
    # One proxy marker per class, colored exactly like its scatter points.
    proxy_markers = [
        plt.Line2D([], [], marker="o", linestyle="",
                   color=scatter.cmap(scatter.norm(class_id)))
        for class_id in legend_classes
    ]
    legend_text = [f"Class {class_id}" for class_id in legend_classes]
    plt.legend(proxy_markers, legend_text, title="Example classes", fontsize=8, loc="best")

    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()
    print(f"[INFO] Saved UMAP plot to {out_path}")
def _parse_max_samples(text):
    """
    Convert the --max-samples command-line value.

    Returns None for "none" or "all" (case-insensitive), otherwise a
    positive int. Raises argparse.ArgumentTypeError for anything else so
    argparse reports a usage error instead of a later traceback.
    """
    if text.strip().lower() in {"none", "all"}:
        return None
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer or 'none', got {text!r}"
        ) from None
    if value <= 0:
        # Zero or negative sizes would only fail later during extraction.
        raise argparse.ArgumentTypeError(
            f"expected a positive integer or 'none', got {text!r}"
        )
    return value


def main():
    """
    Command-line entry point.

    Parses the arguments, extracts raw and ResNet features, then writes
    the t-SNE plots and the UMAP plots into the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="outputs/feature_viz",
        help="Directory to save t-SNE/UMAP plots.",
    )
    parser.add_argument(
        "--max-samples",
        # A plain int type could never produce the None promised by the help.
        type=_parse_max_samples,
        default=2000,
        help="Max number of test samples to subsample for visualization (None = all).",
    )
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # 1) Extract features
    X_raw, X_resnet, y = extract_features(
        data_root=args.data_root,
        max_samples=args.max_samples,
        seed=42,
    )

    # 2) t-SNE on raw features
    tsne_raw_path = out_dir / "tsne_raw.png"
    run_tsne(X_raw, y, tsne_raw_path, title="t-SNE: Raw 64x64 Grayscale Features")

    # 3) t-SNE on ResNet features
    tsne_resnet_path = out_dir / "tsne_resnet.png"
    run_tsne(X_resnet, y, tsne_resnet_path, title="t-SNE: ResNet18 Pretrained Features")

    # 4) Optional UMAP (if available)
    umap_raw_path = out_dir / "umap_raw.png"
    run_umap(X_raw, y, umap_raw_path, title="UMAP: Raw 64x64 Grayscale Features")
    umap_resnet_path = out_dir / "umap_resnet.png"
    run_umap(X_resnet, y, umap_resnet_path, title="UMAP: ResNet18 Pretrained Features")
if __name__ == "__main__":
    # Script entry point: cap the threads PyTorch uses for CPU work at 4
    # before running the pipeline, to keep torch threads manageable.
    torch.set_num_threads(4)
    main()