keysun89 commited on
Commit
3eda9ab
·
verified ·
1 Parent(s): 8e79a95

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import random
6
+ from huggingface_hub import hf_hub_download
7
+ from generator import Generator # Import your generator class
8
+
9
+ # Import your generator class
10
+ # from generator import Generator # Uncomment and adjust to your file
11
+
12
+ wts = ['trial_0_G (1).pth' , 'trial_0_G (2).pth' , 'trial_0_G (3).pth' , 'trial_0_G (4).pth' , 'trial_0_G (5).pth' , 'trial_0_G.pth' ]
13
+ random_wt = random.choice(wts)
14
+
15
+ # Load trained model weights from Hugging Face Hub
16
+ weights_path = hf_hub_download(
17
+ repo_id="keysun89/image_generation", # Replace with your repo
18
+ filename= random_wt # Replace with your weights file
19
+ )
20
+
21
+
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
+
24
+ # Configure your generator parameters
25
+ z_dim = 512
26
+ w_dim = 512
27
+ img_resolution = 256 # Adjust to your training resolution
28
+ img_channels = 3
29
+
30
+ model = Generator(
31
+ z_dim=z_dim,
32
+ w_dim=w_dim,
33
+ img_resolution=img_resolution,
34
+ img_channels=img_channels
35
+ )
36
+
37
+ # Load weights
38
+ model.load_state_dict(torch.load(weights_path, map_location=device))
39
+ model.to(device)
40
+ model.eval()
41
+
42
+ def generate():
43
+ """Generate a random image"""
44
+ with torch.no_grad():
45
+ # Generate random latent vector
46
+ z = torch.randn(1, z_dim, device=device)
47
+
48
+ # Generate image
49
+ img = model(z, use_truncation=True, truncation_psi=0.7)
50
+
51
+ # Convert to PIL Image
52
+ img = img.squeeze(0).cpu().numpy()
53
+ img = np.transpose(img, (1, 2, 0)) # CHW to HWC
54
+ img = (img * 127.5 + 128).clip(0, 255).astype(np.uint8)
55
+
56
+ return Image.fromarray(img)
57
+
58
+ # Gradio interface
59
+ demo = gr.Interface(
60
+ fn=generate,
61
+ inputs=None,
62
+ outputs=gr.Image(type="pil"),
63
+ title="StyleGAN2 Image Generator",
64
+ description="Click 'Submit' or refresh the page to generate a new random image",
65
+ allow_flagging="never"
66
+ )
67
+
68
+ if __name__ == "__main__":
69
+ demo.launch()