seung0h committed on
Commit
083928f
·
1 Parent(s): 620e3e1
Files changed (2) hide show
  1. app.py +42 -0
  2. pipeline.py +211 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from pipeline import SmileGen
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ import os
7
+
8
+
9
def read_samples(path):
    """Load every JPEG/PNG image found directly inside ``path``.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        A list of numpy arrays, one per image file, ordered by file name.
        Files whose name does not end in ``.jpg`` or ``.png`` (compared
        case-insensitively) are ignored.
    """
    samples = []
    # os.listdir gives no ordering guarantee; sort so the result is stable.
    for filename in sorted(os.listdir(path)):
        if not filename.lower().endswith((".jpg", ".png")):
            continue
        # np.array() forces the pixel data to be read, so the file handle can
        # be released as soon as the array has been built.
        with Image.open(os.path.join(path, filename)) as img:
            samples.append(np.array(img))
    return samples
17
+
18
def create_image_generation_demo():
    """Build the Gradio interface that wraps the SmileGen pipeline.

    Returns:
        The configured ``gr.Interface``. It is not launched here; the caller
        decides when to start it.
    """
    # No example images are supplied; an empty list is passed as `examples`.
    example_images = []

    smile_model = SmileGen()

    # One PIL image in, one image out.
    input_components = [gr.Image(label="Input Image", type="pil")]
    output_components = [gr.Image(label="Generated Image")]

    return gr.Interface(
        fn=smile_model.run,
        inputs=input_components,
        outputs=output_components,
        title="Smile!",
        description="Upload an image and generate a new image using a custom pipeline.",
        examples=example_images,
    )
38
+
39
# Script entry point: build the interface and start serving it.
if __name__ == "__main__":
    create_image_generation_demo().launch()
pipeline.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import hf_hub_download
2
+ from ultralytics import YOLO
3
+ from supervision import Detections
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchvision import transforms
10
+ from PIL import Image
11
+ import pandas as pd
12
+ import torchvision.transforms as transforms
13
+ import torchvision
14
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
15
+ from diffusers import AutoPipelineForInpainting
16
+
17
+
18
class SmileGen:
    """Find faces that are not smiling in a photo and inpaint a smile on them.

    ``run`` chains four stages on a PIL image:

    1. ``face_detection``      - YOLOv8 face detector, square crops per face.
    2. ``face_classification`` - EfficientNet-B0 smile / not-smile classifier.
    3. ``gen_mask``            - SegFormer face parsing -> mouth-region masks.
    4. ``kan_inference``       - Kandinsky 2.2 inpainting of the masked area.

    ``make_result`` then pastes the generated crops back into a copy of the
    input.  Each model is downloaded/instantiated the first time it is needed
    and cached on the instance, so repeated calls (e.g. from a web UI) do not
    reload the weights.
    """

    # Segmentation class ids that are treated as the mouth region.
    # NOTE(review): presumably mouth / upper lip / lower lip in the
    # "jonathandinu/face-parsing" label map - confirm against the model card.
    MOUTH_LABELS = (10, 11, 12)
    # Number of pixels the mouth bounding box is grown by on each side.
    MASK_MARGIN = 15
    # Resolution the crop and mask are resized to before inpainting.
    INPAINT_SIZE = (512, 512)

    def __init__(self, device='cuda'):
        """Store the target device; no model is loaded yet.

        Args:
            device: Torch device string.  If a CUDA device is requested but
                CUDA is unavailable, the pipeline falls back to ``'cpu'``
                instead of failing on the first ``.to(device)`` call.
        """
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            device = 'cpu'
        self.device = device

    # ------------------------------------------------------------------
    # Model loading (lazy, cached per instance)
    # ------------------------------------------------------------------
    def _cached(self, name, factory):
        """Return the object cached under ``name``, building it on first use."""
        cache = self.__dict__.setdefault("_models", {})
        if name not in cache:
            cache[name] = factory()
        return cache[name]

    def _load_face_detector(self):
        """Download the YOLOv8 face-detection weights and build the model."""
        weights_path = hf_hub_download(
            repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
        )
        return YOLO(weights_path).to(self.device)

    def _load_face_classifier(self):
        """Build EfficientNet-B0 with a 2-class head and load the smile checkpoint."""
        # No ImageNet weights are requested: every parameter is overwritten by
        # the checkpoint below, so downloading them would be wasted work.
        face_classifier = torchvision.models.efficientnet_b0()
        num_features = face_classifier.classifier[1].in_features
        face_classifier.classifier[1] = nn.Linear(num_features, 2)

        hf_path = hf_hub_download(
            repo_id="seung0h/smile_classification",
            filename="best_efficientnetB0_smile.pth",
        )
        # NOTE: torch.load unpickles the file; only use checkpoints from a
        # trusted repository.  map_location lets a checkpoint saved on GPU be
        # loaded on a CPU-only machine.
        best_ckpt = torch.load(hf_path, map_location=self.device)
        face_classifier.load_state_dict(best_ckpt)
        face_classifier.to(self.device)
        face_classifier.eval()
        return face_classifier

    def _load_face_parser(self):
        """Return ``(processor, model)`` for SegFormer face parsing."""
        seg_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
        face_parser = SegformerForSemanticSegmentation.from_pretrained(
            "jonathandinu/face-parsing"
        ).to(self.device)
        return seg_processor, face_parser

    def _load_inpaint_pipe(self):
        """Build the Kandinsky 2.2 inpainting pipeline on ``self.device``."""
        # Half precision is only used on CUDA; on CPU fall back to float32.
        on_cuda = str(self.device).startswith('cuda')
        dtype = torch.float16 if on_cuda else torch.float32
        return AutoPipelineForInpainting.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=dtype
        ).to(self.device)

    # ------------------------------------------------------------------
    # Pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _square_box(x1, y1, x2, y2, img_w, img_h):
        """Grow the shorter side of a box so it is (about) square, then clip.

        The shorter side is extended symmetrically by half the difference of
        the two sides; the result is clamped to ``[0, img_w] x [0, img_h]``,
        so boxes near the border may end up non-square.

        Returns:
            ``(x1, y1, x2, y2)`` as a tuple.
        """
        width = x2 - x1
        height = y2 - y1
        if width > height:
            pad = (width - height) // 2
            y1 -= pad
            y2 += pad
        else:
            pad = (height - width) // 2
            x1 -= pad
            x2 += pad
        return max(0, x1), max(0, y1), min(img_w, x2), min(img_h, y2)

    @staticmethod
    def _mouth_masks(labels, mouth_labels=MOUTH_LABELS, margin=MASK_MARGIN):
        """Build the inpainting masks from a 2-D array of per-pixel class ids.

        Args:
            labels: ``(H, W)`` integer array of segmentation labels.
            mouth_labels: Class ids that count as mouth pixels.
            margin: Pixels added around the mouth bounding box.

        Returns:
            ``(mask, hole_mask)``, both ``(H, W)`` float arrays holding 0 or
            255.  ``mask`` is the padded mouth bounding box; ``hole_mask`` is
            the same box with a vertical strip around the horizontal centre
            of the mouth cleared.  Both are all-zero when no mouth pixel is
            present.
        """
        labels = np.asarray(labels)
        mask = np.zeros(labels.shape)

        rows, cols = np.nonzero(np.isin(labels, mouth_labels))
        if rows.size == 0:
            return mask, mask.copy()

        row_min, row_max = int(rows.min()), int(rows.max())
        col_min, col_max = int(cols.min()), int(cols.max())

        # Clamp the start indices: a negative slice start would wrap around
        # to the far side of the array and select the wrong (usually empty)
        # region when the mouth is close to the top/left edge.
        top = max(0, row_min - margin)
        left = max(0, col_min - margin)
        mask[top:row_max + margin, left:col_max + margin] = 255

        # Open the centre of the lips for style consistency: the strip is
        # half the mouth width, centred on the mouth.
        center_col = (col_min + col_max) // 2
        hole_half = (col_max - col_min) // 4
        hole_mask = mask.copy()
        hole_mask[:, max(0, center_col - hole_half):center_col + hole_half] = 0
        return mask, hole_mask

    # ------------------------------------------------------------------
    # Pipeline stages
    # ------------------------------------------------------------------
    def face_detection(self, image):
        """Detect faces and return square-ish crops with their boxes.

        Args:
            image: PIL image to search.

        Returns:
            ``(face_crops, face_bboxs)``: parallel lists of PIL crops and
            ``(x1, y1, x2, y2)`` integer boxes in ``image`` coordinates.
        """
        face_det = self._cached("face_det", self._load_face_detector)

        output = face_det(image)
        box_results = Detections.from_ultralytics(output[0])

        W, H = image.size
        face_crops = []
        face_bboxs = []
        for raw_box in box_results.xyxy:
            x1, y1, x2, y2 = map(int, raw_box.tolist())
            box = self._square_box(x1, y1, x2, y2, W, H)
            # A zero-area box cannot be cropped or resized later on.
            if box[2] <= box[0] or box[3] <= box[1]:
                continue
            face_crops.append(image.crop(box))
            face_bboxs.append(box)

        return face_crops, face_bboxs

    def face_classification(self, face_crops, face_bboxs):
        """Keep only the faces the classifier labels as not smiling.

        Args:
            face_crops: PIL face crops.
            face_bboxs: Boxes parallel to ``face_crops``.

        Returns:
            ``(unsmile_imgs, unsmile_boxes)``: the crops whose predicted
            class is not 1 (class 1 is "Smile"), with their boxes.
        """
        unsmile_imgs = []
        unsmile_boxes = []
        if not face_crops:
            return unsmile_imgs, unsmile_boxes

        face_classifier = self._cached("face_classifier", self._load_face_classifier)

        val_transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225]),
        ])

        for img, box in zip(face_crops, face_bboxs):
            # The 3-channel normalisation requires an RGB input.
            img_tensor = val_transforms(img.convert("RGB")).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = face_classifier(img_tensor)
            pred_label = output.argmax(dim=1).item()

            if pred_label != 1:
                unsmile_imgs.append(img)
                unsmile_boxes.append(box)

        return unsmile_imgs, unsmile_boxes

    def gen_mask(self, unsmile_imgs):
        """Create mouth inpainting masks for each face crop.

        Args:
            unsmile_imgs: PIL face crops.

        Returns:
            A list (parallel to ``unsmile_imgs``) of dicts with keys
            ``"mask"`` and ``"hole_mask"``; see ``_mouth_masks``.
        """
        mask_list = []
        if not unsmile_imgs:
            return mask_list

        seg_processor, face_parser = self._cached("face_parser", self._load_face_parser)

        for image in unsmile_imgs:
            inputs = seg_processor(images=image, return_tensors="pt").to(self.device)
            # Inference only: avoid building an autograd graph.
            with torch.no_grad():
                logits = face_parser(**inputs).logits

            # resize output to match input image dimensions
            upsampled_logits = nn.functional.interpolate(logits,
                                                         size=image.size[::-1],  # H x W
                                                         mode='bilinear',
                                                         align_corners=False)

            # Move the label map to the host once instead of reading it
            # pixel by pixel.
            labels = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
            mask, hole_mask = self._mouth_masks(labels)
            mask_list.append({"mask": mask,
                              "hole_mask": hole_mask})

        return mask_list

    def kan_inference(self, unsmile_imgs, mask_list):
        """Inpaint a smile into each crop using its ``"hole_mask"``.

        Args:
            unsmile_imgs: PIL face crops.
            mask_list: Output of ``gen_mask`` for the same crops.

        Returns:
            A list of generated PIL images, one per crop, at the pipeline's
            working resolution (crops and masks are resized to 512x512).
        """
        results = []
        if not unsmile_imgs:
            # Nothing to do: do not load the diffusion pipeline at all.
            return results

        prompt = (
            "a young korean person, smiling softly, mouth closed or gently open, "
            "natural lips, realistic lighting, close-up portrait, high quality, professional studio photo"
        )
        negative_prompt = (
            "bad anatomy, deformed lips, extra mouth, open mouth showing teeth, "
            "distorted face, blurry, low quality, erotic, nsfw, sexual, nude, cleavage, extra face"
        )

        pipe = self._cached("inpaint_pipe", self._load_inpaint_pipe)
        # Re-seeded on every call so each request is reproducible even though
        # the pipeline object itself is reused.
        generator = torch.Generator(device=self.device).manual_seed(42)

        for img, m in zip(unsmile_imgs, mask_list):
            mask_image = Image.fromarray(m["hole_mask"]).resize(self.INPAINT_SIZE)
            init_image = img.resize(self.INPAINT_SIZE)

            result = pipe(prompt=prompt,
                          negative_prompt=negative_prompt,
                          num_inference_steps=20,
                          generator=generator,
                          image=init_image,
                          mask_image=mask_image).images[0]

            results.append(result)

        return results

    def make_result(self, image_orig, results, unsmile_imgs, unsmile_boxes):
        """Paste the generated crops back into a copy of the original image.

        Args:
            image_orig: The full input PIL image (left unmodified).
            results: Generated PIL crops.
            unsmile_imgs: Unused; kept for interface compatibility.
            unsmile_boxes: ``(x1, y1, x2, y2)`` boxes parallel to ``results``.

        Returns:
            A new PIL image in which each box is replaced by the matching
            generated crop, resized to the box size.
        """
        image_restored = image_orig.copy()

        for result, box in zip(results, unsmile_boxes):
            x1, y1, x2, y2 = box
            gen_image = result.resize((x2 - x1, y2 - y1))
            image_restored.paste(gen_image, box=(x1, y1))

        return image_restored

    def run(self, image):
        """Run the whole pipeline on one image.

        Args:
            image: Input PIL image, or ``None`` (e.g. an empty upload).

        Returns:
            A new PIL image with smiles inpainted on the faces classified as
            not smiling, a copy of the (RGB) input when there is nothing to
            change, or ``None`` when ``image`` is ``None``.
        """
        if image is None:
            return None
        # The classifier's normalisation assumes three colour channels.
        if image.mode != "RGB":
            image = image.convert("RGB")

        face_crops, face_bboxs = self.face_detection(image)
        unsmile_imgs, unsmile_boxes = self.face_classification(face_crops, face_bboxs)
        mask_list = self.gen_mask(unsmile_imgs)

        # Faces whose mask is empty (no mouth pixels found) would only be
        # round-tripped through the generator without any edit; skip them.
        todo = [
            (img, box, m)
            for img, box, m in zip(unsmile_imgs, unsmile_boxes, mask_list)
            if m["hole_mask"].any()
        ]
        if not todo:
            return image.copy()
        unsmile_imgs = [item[0] for item in todo]
        unsmile_boxes = [item[1] for item in todo]
        mask_list = [item[2] for item in todo]

        results = self.kan_inference(unsmile_imgs, mask_list)
        image_restored = self.make_result(image, results, unsmile_imgs, unsmile_boxes)

        return image_restored
205
+
206
+
207
# Manual smoke test: run the pipeline on a bundled sample and display it.
if __name__ == "__main__":
    sample_image = Image.open("samples/newjeans.jpg")
    pipeline = SmileGen()
    smiling_image = pipeline.run(sample_image)
    smiling_image.show()