uhdessai committed on
Commit
934e4be
·
1 Parent(s): 9e4dce2

initial commit

Browse files
Files changed (4) hide show
  1. .gitignore +34 -0
  2. app.py +44 -0
  3. gen_images.py +145 -0
  4. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore outputs (generated images)
2
+ outputs/
3
+ *.png
4
+
5
+ # Ignore model weights (handled separately, e.g. via Git LFS or manual upload)
6
+ model/*.pkl
7
+
8
+ # Python cache
9
+ __pycache__/
10
+ *.py[cod]
11
+ *.so
12
+
13
+ # Jupyter notebook checkpoints
14
+ .ipynb_checkpoints/
15
+
16
+ # System files
17
+ .DS_Store
18
+ Thumbs.db
19
+
20
+ # Environment or build artifacts
21
+ env/
22
+ venv/
23
+ *.egg-info/
24
+ dist/
25
+ build/
26
+ .cache/
27
+ *.log
28
+
29
+ # VSCode and IDE settings
30
+ .vscode/
31
+ .idea/
32
+
33
+ # Hugging Face specific
34
+ *.sagemaker/
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import subprocess
3
+ import os
4
+ import random
5
+ from PIL import Image
6
+
7
# Paths
OUTPUT_DIR = "outputs"  # generated PNGs are written here and read back for the gallery
MODEL_PATH = "blazer_model.pkl" # Adjusted for local Hugging Face repo structure

# Ensure the output directory exists
os.makedirs(OUTPUT_DIR, exist_ok=True)
13
+
14
# Function to generate images using StyleGAN3
def generate_images():
    """Run the StyleGAN3 generation script as a subprocess.

    Writes PNGs for a fixed set of seeds into OUTPUT_DIR using the network
    pickle at MODEL_PATH.

    Returns:
        None on success, or an error-message string when the subprocess
        exits with a non-zero status.
    """
    # Pass argv as a list with shell=False: no shell parsing means no quoting
    # pitfalls and no shell-injection surface (the original interpolated a
    # shell string and ran it with shell=True).
    command = [
        "python", "stylegan3/gen_images.py",
        f"--outdir={OUTPUT_DIR}",
        "--trunc=1",
        "--seeds=3-5,7,9,12-14,16-26,29,31,32,34,40,41",
        f"--network={MODEL_PATH}",
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        return f"Error generating images: {e}"
21
+
22
# Function to select a random sample of generated images for the gallery
def get_random_images(count=10):
    """Return up to *count* generated images, chosen at random.

    If fewer than *count* PNGs exist in OUTPUT_DIR, the generator is run
    first to (re)populate it; fewer images may still be returned if
    generation produced fewer files.

    Args:
        count: maximum number of images to return. Default 10, matching the
            5x2 gallery grid in the UI (the old comment said 5, but the code
            has always sampled up to 10).

    Returns:
        A list of PIL.Image.Image objects.
    """
    image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
    if len(image_files) < count:
        generate_images()
        image_files = [f for f in os.listdir(OUTPUT_DIR) if f.endswith(".png")]
    random_images = random.sample(image_files, min(count, len(image_files)))
    return [Image.open(os.path.join(OUTPUT_DIR, img)) for img in random_images]
30
+
31
# Gradio callback: refresh the design pool, then sample images to show.
def generate_and_display():
    """Regenerate designs and return a random selection for the gallery."""
    generate_images()
    images = get_random_images()
    return images
35
+
36
# UI: a single button that regenerates designs and fills a 5x2 gallery.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 AI-Generated Clothing Designs - Blazers")
    generate_button = gr.Button("Generate New Designs")
    # columns=5, rows=2 matches the 10 images sampled by get_random_images().
    output_gallery = gr.Gallery(label="Generated Designs", columns=5, rows=2)
    generate_button.click(fn=generate_and_display, outputs=output_gallery)

if __name__ == "__main__":
    demo.launch()
gen_images.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
def parse_range(s: Union[str, List]) -> List[int]:
    '''Parse a comma separated list of numbers or ranges and return a list of ints.

    Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
    '''
    # Already a list (e.g. supplied programmatically): return unchanged.
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            # 'a-b' is inclusive on both ends, hence the +1.
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges
40
+
41
+ #----------------------------------------------------------------------------
42
+
43
def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]:
    '''Parse a floating point 2-vector of syntax 'a,b'.

    Example:
        '0,1' returns (0,1)
    '''
    # Tuples pass straight through (already parsed).
    if isinstance(s, tuple):
        return s
    pieces = s.split(',')
    if len(pieces) != 2:
        raise ValueError(f'cannot parse 2-vector {s}')
    return (float(pieces[0]), float(pieces[1]))
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
def make_transform(translate: Tuple[float,float], angle: float):
    """Build a 3x3 homogeneous 2-D transform: rotation by *angle* degrees
    combined with an (x, y) translation."""
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    s = np.sin(theta)
    c = np.cos(theta)
    return np.array([
        [c,    s,   translate[0]],
        [-s,   c,   translate[1]],
        [0.0,  0.0, 1.0],
    ])
68
+
69
+ #----------------------------------------------------------------------------
70
+
71
@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True)
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2')
@click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    network_pkl: str,
    seeds: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    translate: Tuple[float,float],
    rotate: float,
    class_idx: Optional[int]
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left).
    python gen_images.py --outdir=out --trunc=1 --seeds=2 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl

    \b
    # Generate uncurated images with truncation using the MetFaces-U dataset
    python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl
    """

    print('Loading networks from "%s"...' % network_pkl)
    # Fall back to CPU when CUDA is unavailable (e.g. CPU-only Hugging Face
    # Spaces); the original hard-coded 'cuda' and crashed on GPU-less hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise click.ClickException('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print ('warn: --class=lbl ignored when running on an unconditional network')

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        # Seed a dedicated RNG so each output is reproducible per seed,
        # independent of global RNG state.
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)

        # Construct an inverse rotation/translation matrix and pass to the generator. The
        # generator expects this matrix as an inverse to avoid potentially failing numerical
        # operations in the network.
        if hasattr(G.synthesis, 'input'):
            m = make_transform(translate, rotate)
            m = np.linalg.inv(m)
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        # Map [-1, 1] floats to [0, 255] uint8 and NCHW -> NHWC for PIL.
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
138
+
139
+
140
+ #----------------------------------------------------------------------------
141
+
142
# CLI entry point: click parses sys.argv and invokes the command.
if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.26.0
2
+ Pillow==10.2.0
3
+ torch==2.0.1
4
+ torchvision==0.15.2
5
+ torchaudio==2.0.2
6
+ ninja==1.11.1
7
+