My-AI-Projects committed on
Commit
4c6e97b
·
1 Parent(s): 23757c0
Files changed (1) hide show
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import os
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+ import io # Ensure this import is present
9
+ from model.u2net import U2NET # Replace with the actual import path
10
+
11
# Constants
MAX_FILE_SIZE = 5 * 1024 * 1024  # 5MB cap enforced on sidebar uploads

# Device configuration: prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the Haar-cascade face detector. The original relied on the XML being
# in the working directory; if it is not there, fall back to the copy that
# ships with OpenCV. An empty cascade would otherwise "detect" nothing
# silently, so fail loudly instead.
_cascade_file = 'haarcascade_frontalface_default.xml'
if not os.path.exists(_cascade_file):
    _cascade_file = os.path.join(cv2.data.haarcascades, 'haarcascade_frontalface_default.xml')
face_cascade = cv2.CascadeClassifier(_cascade_file)
if face_cascade.empty():
    st.error(f"Failed to load face cascade from: {_cascade_file}")
    st.stop()
19
+
20
# Preprocessing function
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched tensor on the active device."""
    # ImageNet channel statistics, matching the pre-trained backbone.
    pipeline = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    batched = pipeline(image).unsqueeze(0)
    return batched.to(device)
27
+
28
# Model loading function
def load_model(model, model_path, device):
    """Load weights from *model_path* into *model* and prepare it for inference.

    Args:
        model: an ``nn.Module`` whose state_dict matches the checkpoint.
        model_path: filesystem path to the saved state_dict.
        device: ``torch.device`` the model should live on.

    Returns:
        The model, moved to *device* and switched to eval mode.

    Raises:
        FileNotFoundError: if *model_path* does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights file not found: {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    # This app only ever runs inference: eval() freezes batch-norm/dropout so
    # single-image batches are scored with running statistics rather than
    # (degenerate) per-batch statistics.
    model.eval()
    return model
35
+
36
# Initialize U2Net model: 3-channel (RGB) input, 1-channel saliency output.
u2net = U2NET(in_ch=3, out_ch=1)

# Load pre-trained model (replace with your model path). If the weights file
# is missing, surface the error in the Streamlit UI and halt this script run
# instead of crashing later during inference.
try:
    u2net = load_model(u2net, "u2net_portrait.pth", device)
except FileNotFoundError as e:
    st.error(f"Error: {e}")
    st.stop()
45
+
46
# Function to detect the largest face in an image
def detect_single_face(face_cascade, img):
    """Return the bounding box ``(x, y, w, h)`` of the largest face, or None."""
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    boxes = face_cascade.detectMultiScale(grayscale, 1.1, 4)

    # No detections: warn in the UI and let callers fall back to the full image.
    if len(boxes) == 0:
        st.warning("No face detected; processing the entire image.")
        return None

    # Keep only the box with the greatest area.
    return max(boxes, key=lambda box: box[2] * box[3])
57
+
58
# Function to crop and resize face region
def crop_face(img, face):
    """Crop *face* out of *img* with margins, pad to square, resize to 512x512.

    When *face* is None the image is returned untouched.
    """
    if face is None:
        return img

    x, y, w, h = face
    img_h, img_w = img.shape[:2]

    # Margins around the detected box: extra headroom above, a little below.
    pad_left = int(float(w) * 0.4)
    pad_right = int(float(w) * 0.4)
    pad_top = int(float(h) * 0.6)
    pad_bottom = int(float(h) * 0.2)

    x0 = max(x - pad_left, 0)
    x1 = min(x + w + pad_right, img_w)
    y0 = max(y - pad_top, 0)
    y1 = min(y + h + pad_bottom, img_h)

    region = img[y0:y1, x0:x1]

    # Promote a grayscale crop to three identical channels.
    if len(region.shape) == 2:
        region = np.repeat(region[:, :, np.newaxis], 3, axis=2)

    # Pad the shorter side with white so the crop is (approximately) square;
    # odd differences lose one pixel to integer division, which the resize
    # below absorbs.
    region_h, region_w = region.shape[:2]
    if region_h > region_w:
        side = (region_h - region_w) // 2
        region = np.pad(region, ((0, 0), (side, side), (0, 0)),
                        mode='constant', constant_values=255)
    elif region_w > region_h:
        side = (region_w - region_h) // 2
        region = np.pad(region, ((side, side), (0, 0), (0, 0)),
                        mode='constant', constant_values=255)

    return cv2.resize(region, (512, 512), interpolation=cv2.INTER_AREA)
91
+
92
# Normalize prediction function
def normPRED(d):
    """Min-max normalize tensor *d* into the [0, 1] range."""
    lo = torch.min(d)
    hi = torch.max(d)
    return (d - lo) / (hi - lo)
97
+
98
# Inference function
def inference(net, input):
    """Run *net* on a BGR image array; return an inverted, normalized map.

    NOTE(review): expects *net* to return seven outputs like U^2-Net — the
    first (fused) map is used.
    """
    # Scale relative to the brightest pixel.
    scaled = input / np.max(input)

    # BGR -> normalized RGB channels using ImageNet statistics
    # (note the channel swap: source index 2 feeds output channel 0).
    planes = np.zeros((scaled.shape[0], scaled.shape[1], 3))
    planes[:, :, 0] = (scaled[:, :, 2] - 0.406) / 0.225
    planes[:, :, 1] = (scaled[:, :, 1] - 0.456) / 0.224
    planes[:, :, 2] = (scaled[:, :, 0] - 0.485) / 0.229

    # HWC -> 1xCHW float32 tensor.
    batch = torch.from_numpy(planes.transpose((2, 0, 1))[np.newaxis, :, :, :])
    batch = batch.type(torch.FloatTensor)
    if torch.cuda.is_available():
        batch = batch.cuda()

    with torch.no_grad():
        d1 = net(batch)[0]

    # Invert the fused saliency map so dark strokes read as high values,
    # then min-max normalize before handing back a numpy array.
    pred = normPRED(1.0 - d1[:, 0, :, :])
    return pred.squeeze().cpu().data.numpy()
118
+
119
# Convert image to pencil drawing
def image_to_pencil_drawing(image, line_size, line_density):
    """Render *image* (a PIL image) as a pencil drawing.

    Pipeline: detect and crop the largest face (falling back to the whole
    image), run U^2-Net on the crop, then tune the sketch with the user's
    line-density multiplier and Gaussian-blur line size.

    Args:
        image: input PIL.Image.
        line_size: odd int — Gaussian kernel size controlling stroke softness.
        line_density: float multiplier applied per pixel to the sketch.

    Returns:
        PIL.Image with the final pencil drawing.
    """
    img_cv = np.array(image)
    face = detect_single_face(face_cascade, img_cv)
    im_face = crop_face(img_cv, face)

    preprocessed_image = preprocess_image(Image.fromarray(im_face))

    with torch.no_grad():
        output = u2net(preprocessed_image)

    # First element of the U^2-Net output tuple is the fused saliency map.
    # NOTE(review): assumes the map is already in [0, 1] — verify, otherwise
    # the uint8 cast below can wrap around.
    output_data = output[0].squeeze().cpu().numpy()
    pencil_drawing = Image.fromarray((output_data * 255).astype(np.uint8), mode='L')

    img_cv = cv2.cvtColor(np.array(pencil_drawing), cv2.COLOR_GRAY2BGR)
    # (Removed dead code: the original re-ran face detection on the sketch and
    # assigned each ROI to an unused local variable — it had no effect.)

    pencil_drawing = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
    pencil_drawing = pencil_drawing.convert('RGB')
    pencil_drawing = pencil_drawing.point(lambda p: p * line_density)

    # Apply Gaussian Blur to adjust line size, then invert back so strokes
    # are dark on a light background.
    pencil_drawing_array = np.array(pencil_drawing)
    blurred = cv2.GaussianBlur(pencil_drawing_array, (line_size, line_size), 0)
    inverted_image = 255 - blurred
    inverted_pencil_drawing = Image.fromarray(inverted_image.astype(np.uint8))

    return inverted_pencil_drawing
151
+
152
# Fix image function to handle default and uploaded images
def fix_image(upload=None):
    """Show the source image and its pencil-drawing side by side, with a
    sidebar download button for the result.

    Falls back to the bundled sample when *upload* is not provided.
    """
    image = Image.open(upload) if upload else Image.open("8.jpg")  # Default image path

    # Side-by-side layout: original on the left, sketch on the right.
    left_col, right_col = st.columns(2)

    with left_col:
        st.image(image, caption="Uploaded Image", use_column_width=True)

    # line_size / line_density are module-level slider values set below.
    pencil_drawing = image_to_pencil_drawing(image, line_size, line_density)

    with right_col:
        st.image(pencil_drawing, caption="Pencil Drawing", use_column_width=True)

    # Serialize the sketch to PNG and expose it as a sidebar download.
    buf = io.BytesIO()
    pencil_drawing.save(buf, format="PNG")
    st.sidebar.download_button(
        label="Download Pencil Drawing",
        data=buf.getvalue(),
        file_name="pencil_drawing.png",
        mime="image/png",
    )
182
+
183
# Streamlit app setup
st.title("Image to Pencil Drawing")

# Sidebar: upload widget plus rendering controls.
st.sidebar.title("Controls :gear:")
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Slider controls for line size and density (odd sizes via step=2 keep the
# Gaussian blur kernel valid).
line_size = st.sidebar.slider("Line Size", min_value=1, max_value=15, value=5, step=2)
line_density = st.sidebar.slider("Line Density", min_value=0.5, max_value=2.0, value=1.0, step=0.1)

# Route to processing: default sample, oversize rejection, or the upload.
if uploaded_file is None:
    fix_image()  # Use default image if none uploaded
elif uploaded_file.size > MAX_FILE_SIZE:
    st.error("The uploaded file is too large. Please upload an image smaller than 5MB.")
else:
    fix_image(upload=uploaded_file)

# Add custom CSS for dark theme
st.markdown(
    """
    <style>
    body {
        background-color: #1E1E1E; /* Dark background color */
        color: #FFFFFF; /* White text color */
    }
    </style>
    """,
    unsafe_allow_html=True
)