SkillForge45 commited on
Commit
53ac1bf
·
verified ·
1 Parent(s): 67dc5db

Create generate.py

Browse files
Files changed (1) hide show
  1. generate.py +37 -0
generate.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import DiffusionModel, UNet
3
+ from torchvision.utils import save_image
4
+ import argparse
5
+ from PIL import Image
6
+
7
+ def generate(prompts, model_path="diffusion_model.pth", image_size=256, device="cuda"):
8
+ # Load model
9
+ model = UNet().to(device)
10
+ model.load_state_dict(torch.load(model_path, map_location=device))
11
+ model.eval()
12
+
13
+ # Setup diffusion
14
+ betas = torch.linspace(1e-4, 0.02, 1000).to(device)
15
+ diffusion = DiffusionModel(model, betas, device)
16
+
17
+ # Generate images
18
+ with torch.no_grad():
19
+ images = diffusion.sample(prompts, image_size=image_size, batch_size=len(prompts))
20
+
21
+ # Save images
22
+ os.makedirs("generated", exist_ok=True)
23
+ for i, img in enumerate(images):
24
+ img = Image.fromarray(img.permute(1, 2, 0).cpu().numpy())
25
+ img.save(f"generated/sample_{i}.png")
26
+
27
+ print(f"Generated {len(images)} images saved in 'generated' folder")
28
+
29
+ if __name__ == "__main__":
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--prompts", nargs="+", required=True, help="Text prompts for generation")
32
+ parser.add_argument("--model", default="diffusion_model.pth", help="Path to trained model")
33
+ parser.add_argument("--size", type=int, default=256, help="Image size")
34
+ args = parser.parse_args()
35
+
36
+ device = "cuda" if torch.cuda.is_available() else "cpu"
37
+ generate(args.prompts, args.model, args.size, device)