AaSiKu committed on
Commit
c5dc62d
·
verified ·
1 Parent(s): c045103

Uploading the model to huggingface

Browse files
Files changed (5) hide show
  1. app.py +38 -0
  2. helper.py +125 -0
  3. netG_A2B_epoch130.pth +3 -0
  4. netG_B2A_epoch130.pth +3 -0
  5. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from helper import *
3
+
4
+ # Set the page title
5
+ st.title("Aging Deaging - Assignment 4")
6
+
7
+ # Create columns for input and output sections
8
+ col1, col2 = st.columns(2)
9
+
10
+ # Input Section
11
+ with col1:
12
+ st.header("Input")
13
+ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
14
+
15
+ if uploaded_image is not None:
16
+ # Display the uploaded image
17
+ st.image(uploaded_image, caption="Uploaded Image")
18
+
19
+ # Display the selected option
20
+
21
+
22
+ # Output Section
23
+ with col2:
24
+ st.header("Output")
25
+ age_conversion_option = st.radio("Select age conversion option", ("Old to Young", "Young to Old"))
26
+ st.write(f"Selected conversion: {age_conversion_option}")
27
+ if st.button("Generate"):
28
+ if uploaded_image is not None:
29
+ # Here you can add your image processing code
30
+ # For now, we'll just display the uploaded image as a placeholder
31
+ if age_conversion_option == "Young to Old":
32
+ processed_image = generate_Y2O(uploaded_image)
33
+ st.image(processed_image, caption="Old you", use_column_width=True)
34
+ elif age_conversion_option == "Old to Young":
35
+ processed_image = generate_O2Y(uploaded_image)
36
+ st.image(processed_image, caption="Young you", use_column_width=True)
37
+ else:
38
+ st.warning("Please upload an image before clicking Generate")
helper.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import cv2
5
+ from torchvision import transforms,datasets
6
+ from PIL import Image
7
+ input_nc = 3
8
+ output_nc = 3
9
+ class SEBlock(nn.Module):
10
+ def __init__(self, channel, reduction=16):
11
+ super(SEBlock, self).__init__()
12
+ self.fc = nn.Sequential(
13
+ nn.AdaptiveAvgPool2d(1), # Squeeze: output size (N, channel, 1, 1)
14
+ nn.Conv2d(channel, channel // reduction, 1),
15
+ nn.ReLU(inplace=True),
16
+ nn.Conv2d(channel // reduction, channel, 1),
17
+ nn.Sigmoid() # Excitation: channel weights between 0 and 1
18
+ )
19
+
20
+ def forward(self, x):
21
+ weights = self.fc(x)
22
+ return x * weights # channel-wise multiplication
23
+
24
+ class ResnetBlock(nn.Module):
25
+ def __init__(self, dim, reduction=16):
26
+ super(ResnetBlock, self).__init__()
27
+ self.conv_block = self.build_conv_block(dim)
28
+ self.se = SEBlock(dim, reduction)
29
+
30
+ def build_conv_block(self, dim):
31
+ conv_block = [
32
+ nn.ReflectionPad2d(1),
33
+ nn.Conv2d(dim, dim, kernel_size=3, padding=0),
34
+ nn.InstanceNorm2d(dim),
35
+ nn.ReLU(True),
36
+ nn.ReflectionPad2d(1),
37
+ nn.Conv2d(dim, dim, kernel_size=3, padding=0),
38
+ nn.InstanceNorm2d(dim)
39
+ ]
40
+ return nn.Sequential(*conv_block)
41
+
42
+ def forward(self, x):
43
+ out = self.conv_block(x)
44
+ out = self.se(out) # apply squeeze-and-excitation
45
+ return x + out
46
+
47
+ class GeneratorResNet(nn.Module):
48
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9):
49
+ super(GeneratorResNet, self).__init__()
50
+
51
+ # Initial convolution block
52
+ model = [nn.ReflectionPad2d(3),
53
+ nn.Conv2d(input_nc, 64, 7),
54
+ nn.InstanceNorm2d(64),
55
+ nn.ReLU(inplace=True)]
56
+
57
+ # Downsampling
58
+ in_features = 64
59
+ out_features = in_features * 2
60
+ for _ in range(2):
61
+ model += [nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
62
+ nn.InstanceNorm2d(out_features),
63
+ nn.ReLU(inplace=True)]
64
+ in_features = out_features
65
+ out_features = in_features * 2
66
+
67
+ # Residual blocks
68
+ for _ in range(n_residual_blocks):
69
+ model += [ResnetBlock(in_features)]
70
+
71
+ # Upsampling
72
+ out_features = in_features // 2
73
+ for _ in range(2):
74
+ model += [nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
75
+ nn.InstanceNorm2d(out_features),
76
+ nn.ReLU(inplace=True)]
77
+ in_features = out_features
78
+ out_features = in_features // 2
79
+
80
+ # Output layer
81
+ model += [nn.ReflectionPad2d(3),
82
+ nn.Conv2d(64, output_nc, 7),
83
+ nn.Tanh()]
84
+
85
+ self.model = nn.Sequential(*model)
86
+
87
+ def forward(self, x):
88
+ return self.model(x)
89
+
90
+
91
+ netG_A2B = GeneratorResNet(input_nc, output_nc)
92
+ netG_B2A = GeneratorResNet(input_nc, output_nc)
93
+ # Load weights for netG_A2B
94
+ device = 'cpu'
95
+ netG_A2B.load_state_dict(torch.load('./netG_A2B_epoch130.pth',map_location=device))
96
+
97
+ # Load weights for netG_B2A
98
+ netG_B2A.load_state_dict(torch.load('./netG_B2A_epoch130.pth',map_location=device))
99
+
100
+ def generate_Y2O(uploaded_image):
101
+
102
+ image = Image.open(uploaded_image) # Open image using PIL
103
+
104
+ # Convert PIL image to OpenCV format (BGR)
105
+ open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
106
+ img = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
107
+ img = cv2.resize(img, (128,128))
108
+ to_tensor = transforms.ToTensor()
109
+ tensor = to_tensor(img)
110
+ old = netG_A2B(tensor)
111
+ return (old.detach().permute(1, 2, 0).numpy()+1)/2
112
+
113
+
114
+ def generate_O2Y(uploaded_image):
115
+
116
+ image = Image.open(uploaded_image) # Open image using PIL
117
+
118
+ # Convert PIL image to OpenCV format (BGR)
119
+ open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
120
+ img = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)
121
+ img = cv2.resize(img, (128,128))
122
+ to_tensor = transforms.ToTensor()
123
+ tensor = to_tensor(img)
124
+ young = netG_B2A(tensor)
125
+ return (young.detach().permute(1, 2, 0).numpy()+1)/2
netG_A2B_epoch130.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:992c0132e46cf2294651f9d5c827a9c28ddb101d9531cf232a46792da0a82eed
3
+ size 45851722
netG_B2A_epoch130.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8299523a9f28f7d28e5ec6b4431c847f7a216b4c4a33daf474e3f70b7ea204f
3
+ size 45851722
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Pillow
2
+ streamlit
3
+ torch
4
+ torchvision
5
+ numpy
6
+ opencv-python