keysun89 commited on
Commit
e47f186
·
verified ·
1 Parent(s): 5fabc96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -12
app.py CHANGED
@@ -5,8 +5,8 @@ from PIL import Image
5
  from torchvision import transforms
6
  import random
7
  from huggingface_hub import hf_hub_download
8
- from generator_1 import Generator as Generator_1 # Import your StyleGAN2 generator
9
- from generator_2 import Generator as Generator_2 # Import your SRGAN generator
10
 
11
  # wts = ['trial_0_G (1).pth', 'trial_0_G (2).pth', 'trial_0_G (3).pth', 'trial_0_G (4).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
12
  wts = ['trial_0_G (2).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
@@ -14,8 +14,8 @@ random_wt = random.choice(wts)
14
 
15
  # Load trained model weights from Hugging Face Hub
16
  weights_path = hf_hub_download(
17
- repo_id="keysun89/image_generation", # Replace with your repo
18
- filename=random_wt # Replace with your weights file
19
  )
20
 
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -26,7 +26,7 @@ w_dim = 512
26
  img_resolution = 256 # Adjust to your training resolution
27
  img_channels = 3
28
 
29
- model = Generator_1(
30
  z_dim=z_dim,
31
  w_dim=w_dim,
32
  img_resolution=img_resolution,
@@ -38,13 +38,14 @@ model.load_state_dict(torch.load(weights_path, map_location=device))
38
  model.to(device)
39
  model.eval()
40
 
 
41
  srgan_weights = hf_hub_download(
42
- repo_id="keysun89/image_generation",
43
- filename= "genrator.pth"
44
  )
45
 
46
  # Initialize SRGAN with scale=2 (256 -> 512)
47
- srgan_model = Generator_2(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=2)
48
  srgan_model.load_state_dict(torch.load(srgan_weights, map_location=device))
49
  srgan_model.to(device)
50
  srgan_model.eval()
@@ -70,14 +71,14 @@ def generate():
70
  # Step 2: Upscale to 512x512 with SRGAN
71
  img_256_tensor = transform(img_256_pil).unsqueeze(0).to(device)
72
 
73
- # Generate high-resolution image
74
- img_512 = srgan_model(img_256_tensor)
75
 
76
  # Convert to PIL Image (512x512)
77
  img_512_np = img_512.squeeze(0).cpu().numpy()
78
  img_512_np = np.transpose(img_512_np, (1, 2, 0)) # CHW to HWC
79
- # Denormalize if needed
80
- img_512_np = (img_512_np * 127.5 + 128).clip(0, 255).astype(np.uint8)
81
  img_512_pil = Image.fromarray(img_512_np)
82
 
83
  return img_256_pil, img_512_pil
 
5
  from torchvision import transforms
6
  import random
7
  from huggingface_hub import hf_hub_download
8
+ from generator_1 import Generator as StyleGANGenerator # Import your StyleGAN2 generator
9
+ from generator_2 import Generator as SRGANGenerator # Import your SRGAN generator
10
 
11
  # wts = ['trial_0_G (1).pth', 'trial_0_G (2).pth', 'trial_0_G (3).pth', 'trial_0_G (4).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
12
  wts = ['trial_0_G (2).pth', 'trial_0_G (5).pth', 'trial_0_G.pth']
 
14
 
15
  # Load trained model weights from Hugging Face Hub
16
  weights_path = hf_hub_download(
17
+ repo_id="keysun89/img_generation", # Fixed repo name
18
+ filename=random_wt
19
  )
20
 
21
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
26
  img_resolution = 256 # Adjust to your training resolution
27
  img_channels = 3
28
 
29
+ model = StyleGANGenerator(
30
  z_dim=z_dim,
31
  w_dim=w_dim,
32
  img_resolution=img_resolution,
 
38
  model.to(device)
39
  model.eval()
40
 
41
+ wt_2 = 'genrator.pth'
42
  srgan_weights = hf_hub_download(
43
+ repo_id="keysun89/img_generation", # Fixed repo name
44
+ filename=wt_2
45
  )
46
 
47
  # Initialize SRGAN with scale=2 (256 -> 512)
48
+ srgan_model = SRGANGenerator(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=2)
49
  srgan_model.load_state_dict(torch.load(srgan_weights, map_location=device))
50
  srgan_model.to(device)
51
  srgan_model.eval()
 
71
  # Step 2: Upscale to 512x512 with SRGAN
72
  img_256_tensor = transform(img_256_pil).unsqueeze(0).to(device)
73
 
74
+ # Generate high-resolution image (SRGAN returns tuple: image, features)
75
+ img_512, _ = srgan_model(img_256_tensor)
76
 
77
  # Convert to PIL Image (512x512)
78
  img_512_np = img_512.squeeze(0).cpu().numpy()
79
  img_512_np = np.transpose(img_512_np, (1, 2, 0)) # CHW to HWC
80
+ # Denormalize from tanh output [-1, 1] to [0, 255]
81
+ img_512_np = (img_512_np * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
82
  img_512_pil = Image.fromarray(img_512_np)
83
 
84
  return img_256_pil, img_512_pil