Update README.md
Browse files
README.md
CHANGED
|
@@ -79,12 +79,134 @@ Limitations:
|
|
| 79 |
|
| 80 |
### Generate images
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
```bash
|
| 83 |
python generate.py \
|
| 84 |
--checkpoint model.pt \
|
| 85 |
--n_images 8 \
|
| 86 |
--steps 50 \
|
| 87 |
--seed 42
|
|
|
|
| 88 |
📁 Output
|
| 89 |
|
| 90 |
Generated images are saved as a horizontal grid:
|
|
|
|
| 79 |
|
| 80 |
### Generate images
|
| 81 |
|
| 82 |
+
## THE INITIAL IDEA WAS A STUDENT U-NET FROM A TEACHER U-NET, BUT THIS WAS DISCONTINUED BECAUSE THE TEACHER WAS INITIALIZATED WITH RANDOM WEIGHTS, THAT WOULD KILL THE STUDENT LEARNING
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
```python
|
| 86 |
+
|
| 87 |
+
import argparse
|
| 88 |
+
import torch
|
| 89 |
+
from pathlib import Path
|
| 90 |
+
|
| 91 |
+
from train import StudentUNet, DDPMScheduler, Config
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ----------------------------
|
| 95 |
+
# Sampling (com controle de steps)
|
| 96 |
+
# ----------------------------
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def generate_samples(model, scheduler, n=4, steps=50, device="cpu", dtype=torch.float32):
|
| 99 |
+
model.eval()
|
| 100 |
+
|
| 101 |
+
x = torch.randn(n, 3, cfg.image_size, cfg.image_size, device=device, dtype=dtype)
|
| 102 |
+
|
| 103 |
+
# 🔥 Usa menos steps → muito mais rápido
|
| 104 |
+
step_size = scheduler.T // steps
|
| 105 |
+
timesteps = list(range(0, scheduler.T, step_size))
|
| 106 |
+
timesteps = list(reversed(timesteps))
|
| 107 |
+
|
| 108 |
+
for t_val in timesteps:
|
| 109 |
+
t = torch.full((n,), t_val, device=device, dtype=torch.long)
|
| 110 |
+
noise_pred = model(x, t)
|
| 111 |
+
|
| 112 |
+
if t_val > 0:
|
| 113 |
+
ab = scheduler.alpha_bar[t_val].to(x.dtype)
|
| 114 |
+
prev_t = max(t_val - step_size, 0)
|
| 115 |
+
ab_prev = scheduler.alpha_bar[prev_t].to(x.dtype)
|
| 116 |
+
|
| 117 |
+
beta_t = 1.0 - (ab / ab_prev)
|
| 118 |
+
alpha_t = 1.0 - beta_t
|
| 119 |
+
|
| 120 |
+
mean = (1.0 / alpha_t.sqrt()) * (
|
| 121 |
+
x - (beta_t / (1.0 - ab).sqrt()) * noise_pred
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
sigma = beta_t.sqrt()
|
| 125 |
+
x = mean + sigma * torch.randn_like(x)
|
| 126 |
+
else:
|
| 127 |
+
x = scheduler.predict_x0(x, noise_pred, t)
|
| 128 |
+
|
| 129 |
+
model.train()
|
| 130 |
+
return x.clamp(-1, 1)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ----------------------------
|
| 134 |
+
# Save
|
| 135 |
+
# ----------------------------
|
| 136 |
+
def save_samples(samples, path: Path):
|
| 137 |
+
samples = (samples + 1) / 2
|
| 138 |
+
samples = (samples * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
|
| 139 |
+
|
| 140 |
+
from PIL import Image
|
| 141 |
+
|
| 142 |
+
n = len(samples)
|
| 143 |
+
w = samples.shape[1]
|
| 144 |
+
|
| 145 |
+
grid = Image.new("RGB", (n * w, w))
|
| 146 |
+
|
| 147 |
+
for i, s in enumerate(samples):
|
| 148 |
+
grid.paste(Image.fromarray(s), (i * w, 0))
|
| 149 |
+
|
| 150 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 151 |
+
grid.save(path)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ----------------------------
|
| 155 |
+
# Main
|
| 156 |
+
# ----------------------------
|
| 157 |
+
def main():
|
| 158 |
+
parser = argparse.ArgumentParser()
|
| 159 |
+
parser.add_argument("--checkpoint", type=str, required=True)
|
| 160 |
+
parser.add_argument("--n_images", type=int, default=8)
|
| 161 |
+
parser.add_argument("--steps", type=int, default=50)
|
| 162 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 163 |
+
parser.add_argument("--out", type=str, default="outputs/generated.png")
|
| 164 |
+
|
| 165 |
+
args = parser.parse_args()
|
| 166 |
+
|
| 167 |
+
# Seed
|
| 168 |
+
torch.manual_seed(args.seed)
|
| 169 |
+
|
| 170 |
+
# Load checkpoint
|
| 171 |
+
ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
|
| 172 |
+
global cfg
|
| 173 |
+
cfg = ckpt.get("config", Config())
|
| 174 |
+
|
| 175 |
+
# Model
|
| 176 |
+
model = StudentUNet(cfg)
|
| 177 |
+
model.load_state_dict(ckpt["model_state"])
|
| 178 |
+
model.eval()
|
| 179 |
+
|
| 180 |
+
# Scheduler
|
| 181 |
+
scheduler = DDPMScheduler(cfg.timesteps, cfg.beta_start, cfg.beta_end)
|
| 182 |
+
|
| 183 |
+
print(f"\n🚀 Generating {args.n_images} images")
|
| 184 |
+
print(f"⚙️ Steps: {args.steps} | Seed: {args.seed}")
|
| 185 |
+
|
| 186 |
+
samples = generate_samples(
|
| 187 |
+
model,
|
| 188 |
+
scheduler,
|
| 189 |
+
n=args.n_images,
|
| 190 |
+
steps=args.steps,
|
| 191 |
+
dtype=cfg.dtype
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
save_samples(samples, Path(args.out))
|
| 195 |
+
|
| 196 |
+
print(f"✅ Saved to: {args.out}")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
main()
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
```bash
|
| 204 |
python generate.py \
|
| 205 |
--checkpoint model.pt \
|
| 206 |
--n_images 8 \
|
| 207 |
--steps 50 \
|
| 208 |
--seed 42
|
| 209 |
+
|
| 210 |
📁 Output
|
| 211 |
|
| 212 |
Generated images are saved as a horizontal grid:
|