collins909 committed on
Commit
e00756d
·
verified ·
1 Parent(s): 8c323ac

Upload 4 files

Browse files
run_demo.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ run_demo.py — Self-contained dummy demo of upload_to_hub.py
4
+ ============================================================
5
+ Builds a fake HF deployment package WITHOUT requiring torch or a real
6
+ checkpoint, so you can see exactly what files get uploaded.
7
+
8
+ This demo:
9
+ 1. Creates a dummy checkpoint, args.json, label stats files
10
+ 2. Patches torch import to a stub so upload_to_hub.py can run
11
+ 3. Calls package_model() in dry-run mode
12
+ 4. Lists every file in the package with its purpose
13
+
14
+ Run:
15
+ python run_demo.py
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ import shutil
21
+ import sys
22
+ import types
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+
27
+
28
+ # ── §1 Build a torch stub (so upload_to_hub.py can be imported) ───────────
29
+
30
+ class _TorchStub:
31
+ class Tensor:
32
+ def __init__(self, data):
33
+ self._d = np.asarray(data)
34
+ self.shape = self._d.shape
35
+ def numel(self): return int(np.prod(self.shape))
36
+ def clone(self): return self
37
+ def contiguous(self): return self
38
+ @property
39
+ def dtype(self): return _DType()
40
+ @staticmethod
41
+ def load(path, **kw):
42
+ # Simulate loading our dummy checkpoint
43
+ return _DUMMY_CKPT
44
+ @staticmethod
45
+ def save(obj, path):
46
+ # Mimic torch.save — for the .bin fallback path
47
+ with open(path, "wb") as f:
48
+ f.write(b"DUMMY_TORCH_BIN")
49
+
50
+ class _DType:
51
+ @property
52
+ def is_floating_point(self): return True
53
+
54
+
55
+ # Mock checkpoint structure that mirrors a real DDPM checkpoint
56
# Parameter names and array shapes shared by both weight dictionaries below.
_PARAM_SHAPES = {
    "unet.conv.weight": (64, 1, 3, 3),
    "unet.conv.bias": 64,
    "unet.label_emb.weight": (64, 2),
    "unet.label_emb.bias": 64,
    "unet.out.weight": (1, 64, 1, 1),
    "unet.out.bias": 1,
}

# Mock checkpoint: a raw state dict, an EMA copy with the same keys, and an
# epoch counter. Every array is float32 zeros except the EMA conv weight,
# which is filled with 0.01 so the two dictionaries are distinguishable.
_DUMMY_CKPT = {
    "model_state_dict": {
        key: _TorchStub.Tensor(np.zeros(shape, dtype=np.float32))
        for key, shape in _PARAM_SHAPES.items()
    },
    "ema_shadow": {
        key: _TorchStub.Tensor(
            np.ones(shape, dtype=np.float32) * 0.01
            if key == "unet.conv.weight"
            else np.zeros(shape, dtype=np.float32)
        )
        for key, shape in _PARAM_SHAPES.items()
    },
    "epoch": 100,
}
75
+
76
+ # Stub safetensors too (writes a fake binary blob)
77
+ class _SafetensorsStub:
78
+ @staticmethod
79
+ def save_file(state_dict, path):
80
+ # Just write a fake header so file exists with realistic size
81
+ # In reality safetensors writes a JSON header + binary tensor data
82
+ total_bytes = sum(t.numel() * 4 for t in state_dict.values())
83
+ with open(path, "wb") as f:
84
+ f.write(b"\x00" * total_bytes)
85
+
86
+
87
+ # ── §2 Set up the dummy project ───────────────────────────────────────────
88
+
89
DEMO_ROOT = Path("/tmp/ddpm_hf_demo")
PROJECT = DEMO_ROOT.joinpath("project")
EXPORT = DEMO_ROOT.joinpath("hf_export")

# Always start from an empty demo tree.
if DEMO_ROOT.exists():
    shutil.rmtree(DEMO_ROOT)
(PROJECT / "checkpoints").mkdir(parents=True)

# Minimal source files written into the dummy project directory.
_diffusion_stub = (
    '"""Stub: our DDPM forward/reverse process implementation."""\n'
    'import torch.nn as nn\n'
    'class GaussianDiffusion(nn.Module): ...\n'
    'class ConditionalDiffusionModel(nn.Module): ...\n'
)
(PROJECT / "diffusion_conditional.py").write_text(_diffusion_stub)

_unet_stub = (
    '"""Stub: our conditional U-Net architecture."""\n'
    'import torch.nn as nn\n'
    'class ConditionalUNet(nn.Module): ...\n'
)
(PROJECT / "unet_conditional.py").write_text(_unet_stub)

# Placeholder checkpoint file; its bytes are never parsed because the
# torch stub's load() returns the in-memory dummy checkpoint instead.
(PROJECT / "checkpoints" / "best_model.pt").write_bytes(b"DUMMY_CKPT")

# Training configuration, serialised as args.json.
_training_args = {
    "image_size": 256,
    "label_dim": 2,
    "base_channels": 64,
    "channel_multipliers": [1, 2, 4, 8],
    "attention_levels": [2, 3],
    "dropout": 0.1,
    "timesteps": 1500,
    "beta_start": 1e-4,
    "beta_end": 0.02,
    "schedule_type": "linear",
    "ddim_steps": 50,
    "epochs": 100,
    "batch_size": 8,
    "lr": 2e-4,
    "ema_decay": 0.9999,
    "seed": 42,
}
(PROJECT / "args.json").write_text(json.dumps(_training_args, indent=2))

# 50 random two-column label rows, drawn uniformly from [0.1, 0.5) and
# [0.6, 1.0) respectively, saved as float32.
labels = np.random.uniform(
    low=[0.1, 0.6], high=[0.5, 1.0], size=(50, 2)
).astype(np.float32)
np.save(PROJECT / "train_labels_LH_2.npy", labels)
128
+
129
+
130
+ # ── §3 Inject stubs into sys.modules and import upload_to_hub ─────────────
131
+
132
# Register the stand-ins before upload_to_hub is imported so that its own
# imports of these module names resolve to the stubs.
sys.modules.update({
    "torch": _TorchStub(),
    "safetensors": types.ModuleType("safetensors"),
    "safetensors.torch": _SafetensorsStub(),
})


def _noop(*args, **kwargs):
    """Accept any arguments and do nothing."""
    return None


# Also stub huggingface_hub so we don't hit the network
class _HfStub:
    """Stand-in for ``huggingface_hub`` whose entry points are all no-ops."""

    class HfApi:
        create_repo = _noop
        upload_folder = _noop

    login = _noop


sys.modules["huggingface_hub"] = _HfStub()

# Make the directory containing this script importable, then import the
# packager with the stubs already in place.
sys.path.insert(0, str(Path(__file__).parent))
import upload_to_hub
147
+
148
+
149
+ # ── §4 Run package_model() in dry-run mode ────────────────────────────────
150
+
151
class FakeArgs:
    """Attribute bag passed to ``package_model`` in place of parsed CLI args."""

    checkpoint = str(PROJECT / "checkpoints" / "best_model.pt")
    training_args = str(PROJECT / "args.json")
    data_dir = str(PROJECT)
    export_dir = str(EXPORT)
    no_ema = False
    repo_id = "demo-user/camels-ddpm-omega-sigma8"


# Three-line banner, then build the package from the dummy project.
print(
    "=" * 65,
    " DDPM -> Hugging Face Hub Packager (DUMMY DEMO)",
    "=" * 65,
    sep="\n",
)
folder = upload_to_hub.package_model(FakeArgs())
163
+
164
+
165
+ # ── §5 Verify the result ──────────────────────────────────────────────────
166
+
167
def _format_size(num_bytes):
    """Render a byte count as ``<x.x>M``, ``<x.x>K`` or ``<n>B``."""
    if num_bytes > 1e6:
        return f"{num_bytes/1e6:.1f}M"
    if num_bytes > 1e3:
        return f"{num_bytes/1e3:.1f}K"
    return f"{num_bytes}B"


_rule = "=" * 65
_thin_rule = "-" * 65

print("\n" + _rule)
print(" Package verification")
print(_rule)

# Dump the generated config.json back out, pretty-printed.
config = json.loads((folder / "config.json").read_text())
print("\nconfig.json contents:")
print(json.dumps(config, indent=2))

# Show at most the first 50 lines of the generated README.
_readme_head = (folder / "README.md").read_text().splitlines()[:50]
print("\nREADME.md preview (first 50 lines):")
print(_thin_rule)
print("\n".join(_readme_head))
print("...")
print(_thin_rule)

# One row per entry in the package folder: name, size, known purpose.
print(f"\nFile listing of {folder}:")
files = sorted(folder.iterdir())
print(f"\n{'File':<32} {'Size':>10} Purpose")
print("-" * 75)
purposes = {
    "config.json": "Architecture hyperparameters (hub-readable)",
    "model.safetensors": "Model weights (EMA preferred)",
    "pytorch_model.bin": "Model weights (fallback if no safetensors)",
    "README.md": "Model card with YAML metadata + usage docs",
    "modeling_ddpm_camels.py": "Self-contained loader for `from_pretrained`",
    "diffusion_conditional.py": "Project file: forward/reverse DDPM process",
    "unet_conditional.py": "Project file: U-Net architecture",
    "inference_example.py": "Standalone demo script for users",
    "requirements.txt": "Pinned Python dependencies",
    ".gitattributes": "Git LFS configuration for large files",
}
for entry in files:
    size_text = _format_size(entry.stat().st_size)
    # Files without a known purpose get an empty last column.
    note = purposes.get(entry.name, "")
    print(f" {entry.name:<30} {size_text:>10} {note}")

print(f"\nDemo complete -> {folder}")
for _line in (
    "In a real run, the next step is:",
    " python upload_to_hub.py --checkpoint best_model.pt \\",
    " --training_args args.json \\",
    " --repo_id YOUR_USERNAME/camels-ddpm \\",
    " --private",
):
    print(_line)
sample_config.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "conditional_ddpm_camels",
3
+ "in_channels": 1,
4
+ "out_channels": 1,
5
+ "image_size": 256,
6
+ "label_dim": 2,
7
+ "label_names": [
8
+ "Omega_m",
9
+ "sigma_8"
10
+ ],
11
+ "base_channels": 64,
12
+ "channel_multipliers": [
13
+ 1,
14
+ 2,
15
+ 4,
16
+ 8
17
+ ],
18
+ "attention_levels": [
19
+ 2,
20
+ 3
21
+ ],
22
+ "dropout": 0.1,
23
+ "timesteps": 1500,
24
+ "beta_start": 0.0001,
25
+ "beta_end": 0.02,
26
+ "schedule_type": "linear",
27
+ "ddim_steps_default": 50,
28
+ "framework": "pytorch",
29
+ "library_name": "pytorch",
30
+ "training_meta": {
31
+ "epochs": 100,
32
+ "batch_size": 8,
33
+ "lr": 0.0002,
34
+ "ema_decay": 0.9999,
35
+ "seed": 42
36
+ },
37
+ "label_mu": [
38
+ 0.3308129608631134,
39
+ 0.7831979990005493
40
+ ],
41
+ "label_std": [
42
+ 0.1140434592962265,
43
+ 0.12279357761144638
44
+ ]
45
+ }
sample_inference_example.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference_example.py
3
+ ====================
4
+ Standalone script demonstrating how to use the deployed DDPM model.
5
+ After downloading from the Hub, run:
6
+ python inference_example.py
7
+ """
8
+ import json
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ import matplotlib
13
+ matplotlib.use("Agg")
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ import torch
17
+
18
+ # Ensure local imports resolve
19
+ sys.path.insert(0, str(Path(__file__).parent))
20
+
21
+ from modeling_ddpm_camels import load_pretrained, generate
22
+
23
+ # ── Configuration ──────────────────────────────────────────────────────────
24
# Weights and config are expected to sit next to this script.
MODEL_DIR = Path(__file__).parent
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load ───────────────────────────────────────────────────────────────────
print(f"Loading model from {MODEL_DIR} on {DEVICE} ...")
model, config = load_pretrained(MODEL_DIR, device=DEVICE)
# Single-quoted keys inside the replacement fields: a backslash-escaped
# double quote there (config[\"...\"]) is a SyntaxError before Python 3.12.
print(f" Image size: {config['image_size']}")
print(f" Label dim: {config['label_dim']} ({config['label_names']})")
32
+
33
+ # ── Generate at 4 cosmologies ──────────────────────────────────────────────
34
# Four (first, second) parameter pairs; the plot titles later label them
# as Omega_m and sigma_8.
raw_labels = torch.tensor(
    [
        [0.20, 0.95],
        [0.30, 0.80],
        [0.40, 0.70],
        [0.50, 0.65],
    ],
    dtype=torch.float32,
)

if config["label_dim"] > 2:
    # Fill every label column beyond the first two with the config's
    # label_mu entry for that column, repeated on each of the four rows.
    pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32)
    pad = pad.unsqueeze(0).expand(4, -1)
    raw_labels = torch.cat((raw_labels, pad), dim=1)

print("\nGenerating samples ...")
with torch.no_grad():
    out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50)

# Map [-1, 1] -> [0, 1] for visualisation, then keep channel 0 only.
_unit_range = (out.cpu().numpy() + 1) / 2
imgs = _unit_range.clip(0, 1)[:, 0]
52
+
53
+ # ── Display ────────────────────────────────────────────────────────────────
54
# One panel per generated image, laid out in a single row.
_n_panels = len(imgs)
fig, axes = plt.subplots(1, _n_panels, figsize=(3 * _n_panels, 3.5))
for panel, image, label_row in zip(axes, imgs, raw_labels):
    panel.imshow(image, cmap="magma", origin="lower", vmin=0, vmax=1)
    # Title shows the first two label values of the row.
    panel.set_title(
        f"$\\Omega_m={label_row[0]:.2f}$, $\\sigma_8={label_row[1]:.2f}$",
        fontsize=10,
    )
    panel.set_xticks([])
    panel.set_yticks([])
plt.suptitle("Conditional DDPM samples — CAMELS HI fields", fontweight="bold")
plt.tight_layout()
plt.savefig("inference_example.png", dpi=150, bbox_inches="tight")
print("\nSaved -> inference_example.png")
sample_modeling_ddpm_camels.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_ddpm_camels.py
3
+ =======================
4
+ Self-contained loader for the conditional DDPM checkpoint hosted on the Hub.
5
+ Users only need this file + diffusion_conditional.py + unet_conditional.py
6
+ + config.json + model.safetensors to run inference.
7
+ """
8
+ from __future__ import annotations
9
+ import json
10
+ from pathlib import Path
11
+ from typing import Dict, Tuple, Union
12
+
13
+ import torch
14
+
15
+ from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
16
+ from unet_conditional import ConditionalUNet
17
+
18
+
19
def build_model(config: Dict) -> ConditionalDiffusionModel:
    """Construct the U-Net and diffusion process described by ``config``.

    Every value is read from ``config`` by key and coerced to the type the
    constructors are called with (``int``, ``float``, ``list`` or ``str``);
    a missing key raises ``KeyError``.
    """
    unet_kwargs = {
        key: int(config[key])
        for key in ("in_channels", "out_channels", "label_dim", "base_channels")
    }
    unet_kwargs["channel_multipliers"] = list(config["channel_multipliers"])
    unet_kwargs["attention_levels"] = list(config["attention_levels"])
    unet_kwargs["dropout"] = float(config["dropout"])
    unet = ConditionalUNet(**unet_kwargs)

    diffusion_kwargs = {
        "timesteps": int(config["timesteps"]),
        "beta_start": float(config["beta_start"]),
        "beta_end": float(config["beta_end"]),
        "schedule_type": str(config["schedule_type"]),
    }
    diffusion = GaussianDiffusion(**diffusion_kwargs)

    return ConditionalDiffusionModel(unet, diffusion)
37
+
38
+
39
def load_pretrained(
    model_dir: Union[str, Path],
    device: str = "cuda",
) -> Tuple[ConditionalDiffusionModel, Dict]:
    """
    Build the model from ``config.json`` in ``model_dir`` and load its weights.

    ``model.safetensors`` is used when present, otherwise
    ``pytorch_model.bin``; if neither exists ``FileNotFoundError`` is raised.
    Returns ``(model, config)`` with the model in eval mode and every
    parameter's gradient tracking switched off.
    """
    root = Path(model_dir)
    config = json.loads((root / "config.json").read_text())

    model = build_model(config).to(device)

    weights_st = root / "model.safetensors"
    weights_bin = root / "pytorch_model.bin"
    if weights_st.exists():
        from safetensors.torch import load_file
        state_dict = load_file(str(weights_st), device=device)
    else:
        if not weights_bin.exists():
            raise FileNotFoundError(f"No model weights in {root}")
        state_dict = torch.load(weights_bin, map_location=device, weights_only=True)

    # Non-strict load: key mismatches are reported (first five of each
    # kind) instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    for kind, keys in (("missing", missing), ("unexpected", unexpected)):
        if keys:
            print(f" Warning: {kind} keys: {keys[:5]}{'...' if len(keys) > 5 else ''}")

    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)

    return model, config
75
+
76
+
77
+ # Convenience for one-shot inference
78
def generate(
    model: "ConditionalDiffusionModel",
    config: Dict,
    raw_labels: torch.Tensor,  # (B, label_dim) — un-normalised cosmological params
    n_samples: int = 1,
    use_ddim: bool = True,
    ddim_steps: Union[int, None] = None,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate samples conditioned on raw (un-normalised) parameter values.

    Args:
        model: Object exposing ``sample(labels=..., channels=..., height=...,
            width=..., use_ddim=..., ddim_steps=..., progress=..., device=...)``.
        config: Dict providing ``label_mu``, ``label_std``, ``image_size``
            and, when ``ddim_steps`` is omitted, ``ddim_steps_default``.
            ``in_channels`` is used for the channel count when present
            (defaults to 1).
        raw_labels: ``(B, label_dim)`` values, or a single ``(label_dim,)``
            vector which is treated as a batch of one.
        n_samples: Number of consecutive repeats of each label row.
        use_ddim: Forwarded to ``model.sample``.
        ddim_steps: Forwarded to ``model.sample``; ``None`` selects
            ``config["ddim_steps_default"]``.
        device: Device on which the label tensors are created.

    Returns:
        Whatever ``model.sample`` returns for the ``B * n_samples``
        normalised label rows.

    Raises:
        ValueError: If ``raw_labels`` is not 1-D or 2-D, or its last
            dimension differs from the length of ``config["label_mu"]``.
    """
    if ddim_steps is None:
        ddim_steps = config["ddim_steps_default"]

    label_mu = torch.tensor(config["label_mu"], dtype=torch.float32, device=device)
    label_std = torch.tensor(config["label_std"], dtype=torch.float32, device=device)

    # Cast to float32 so the normalised labels keep the dtype of
    # label_mu/label_std regardless of the input dtype.
    raw_labels = torch.as_tensor(raw_labels, dtype=torch.float32, device=device)
    if raw_labels.dim() == 1:
        raw_labels = raw_labels.unsqueeze(0)

    # Check the shape explicitly: a (B, 1) input would otherwise broadcast
    # silently against label_mu and condition on the wrong values.
    expected_dim = label_mu.numel()
    if raw_labels.dim() != 2 or raw_labels.shape[1] != expected_dim:
        raise ValueError(
            f"raw_labels must have shape (B, {expected_dim}); "
            f"got {tuple(raw_labels.shape)}"
        )

    norm_labels = (raw_labels - label_mu) / label_std
    norm_labels = norm_labels.repeat_interleave(n_samples, dim=0)

    H = W = config["image_size"]
    # Take the channel count from the config instead of hard-coding 1.
    channels = int(config.get("in_channels", 1))
    return model.sample(
        labels=norm_labels, channels=channels, height=H, width=W,
        use_ddim=use_ddim, ddim_steps=ddim_steps,
        progress=False, device=device,
    )
107
+ )