AxionLab-official committed on
Commit fbce1fc · verified · 1 Parent(s): 40b58e8

Update README.md

Files changed (1): README.md +122 -0

README.md CHANGED
@@ -79,12 +79,134 @@ Limitations:
### Generate images
## Note: the initial idea was to distill a student U-Net from a teacher U-Net, but this was discontinued because the teacher was initialized with random weights, which would have prevented the student from learning.

```python
import argparse
from pathlib import Path

import torch
from PIL import Image

from train import StudentUNet, DDPMScheduler, Config


# ----------------------------
# Sampling (with configurable step count)
# ----------------------------
@torch.no_grad()
def generate_samples(model, scheduler, cfg, n=4, steps=50, device="cpu", dtype=torch.float32):
    model.eval()

    x = torch.randn(n, 3, cfg.image_size, cfg.image_size, device=device, dtype=dtype)

    # 🔥 Fewer steps → much faster sampling
    step_size = scheduler.T // steps
    timesteps = list(reversed(range(0, scheduler.T, step_size)))

    for t_val in timesteps:
        t = torch.full((n,), t_val, device=device, dtype=torch.long)
        noise_pred = model(x, t)

        if t_val > 0:
            ab = scheduler.alpha_bar[t_val].to(x.dtype)
            prev_t = max(t_val - step_size, 0)
            ab_prev = scheduler.alpha_bar[prev_t].to(x.dtype)

            # Effective beta/alpha for one strided step
            beta_t = 1.0 - (ab / ab_prev)
            alpha_t = 1.0 - beta_t

            mean = (1.0 / alpha_t.sqrt()) * (
                x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
            )

            sigma = beta_t.sqrt()
            x = mean + sigma * torch.randn_like(x)
        else:
            # Final step: predict x0 directly, no noise added
            x = scheduler.predict_x0(x, noise_pred, t)

    return x.clamp(-1, 1)


# ----------------------------
# Save
# ----------------------------
def save_samples(samples, path: Path):
    # Map from [-1, 1] to [0, 255] and to HWC uint8
    samples = (samples + 1) / 2
    samples = (samples * 255).byte().permute(0, 2, 3, 1).cpu().numpy()

    n = len(samples)
    w = samples.shape[1]  # image height == width (square images)

    # Paste the n images side by side into one horizontal grid
    grid = Image.new("RGB", (n * w, w))
    for i, s in enumerate(samples):
        grid.paste(Image.fromarray(s), (i * w, 0))

    path.parent.mkdir(parents=True, exist_ok=True)
    grid.save(path)


# ----------------------------
# Main
# ----------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--n_images", type=int, default=8)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--out", type=str, default="outputs/generated.png")
    args = parser.parse_args()

    # Seed
    torch.manual_seed(args.seed)

    # Load checkpoint
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    cfg = ckpt.get("config", Config())

    # Model
    model = StudentUNet(cfg)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    # Scheduler
    scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)

    print(f"\n🚀 Generating {args.n_images} images")
    print(f"⚙️ Steps: {args.steps} | Seed: {args.seed}")

    samples = generate_samples(
        model,
        scheduler,
        cfg,
        n=args.n_images,
        steps=args.steps,
        dtype=cfg.dtype,
    )

    save_samples(samples, Path(args.out))

    print(f"✅ Saved to: {args.out}")


if __name__ == "__main__":
    main()
```
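To see what the strided schedule in `generate_samples` does, here is a minimal standalone sketch. The linear beta schedule and its endpoint values are assumptions for illustration; the real ones come from `cfg.beta_start` and `cfg.beta_end` in the checkpoint:

```python
T = 1000     # scheduler.T in the script (assumed default)
steps = 50   # the --steps default

# Strided timestep schedule, exactly as generate_samples builds it
step_size = T // steps
timesteps = list(reversed(range(0, T, step_size)))  # 980, 960, ..., 0

# An assumed linear beta schedule with common DDPM endpoint values
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# Cumulative product alpha_bar[t] = prod_{i<=t} (1 - beta_i)
alpha_bar = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b
    alpha_bar.append(prod)

# Effective beta for one strided step: 1 - alpha_bar[t] / alpha_bar[t - step_size]
t = timesteps[0]
beta_eff = 1.0 - alpha_bar[t] / alpha_bar[t - step_size]
```

With 50 steps the sampler visits only every 20th timestep, and each reverse step uses the ratio of `alpha_bar` values to absorb the 20 skipped steps into one effective beta.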

```bash
python generate.py \
  --checkpoint model.pt \
  --n_images 8 \
  --steps 50 \
  --seed 42
```
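With the defaults above, the grid geometry and pixel mapping used by `save_samples` can be checked without running the model. The `image_size = 64` here is an assumed value; the real one comes from the checkpoint config:

```python
n_images = 8      # the --n_images default
image_size = 64   # assumed; actually cfg.image_size from the checkpoint

# save_samples pastes image i at x-offset i * image_size, so the grid is
# n_images * image_size wide and image_size tall.
grid_width = n_images * image_size
grid_height = image_size

# Pixel-value mapping applied before saving: [-1, 1] -> [0, 255]
def to_uint8(v):
    return int((v + 1) / 2 * 255)
```

So eight 64×64 samples produce a single 512×64 PNG, with model outputs at -1 mapping to black (0) and +1 to white (255).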

### 📁 Output

Generated images are saved as a horizontal grid: