Carzit commited on
Commit
ba6a721
·
verified ·
1 Parent(s): 409afcb

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from PIL import Image
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ from numpy import random
8
+
9
+ from models.experimental import attempt_load
10
+ from utils.datasets import LoadStreams, LoadImages
11
+ from utils.general import (
12
+ check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, plot_one_box, strip_optimizer)
13
+ from utils.torch_utils import select_device, load_classifier, time_synchronized
14
+
15
+ import gradio as gr
16
+ import huggingface_hub
17
+
18
+ from crop import crop
19
+
20
class FaceCrop:
    """Detect anime faces with a YOLOv5 model and crop them out of input images.

    Typical usage:
        pipeline = FaceCrop()
        pipeline.load_model(weights_path)
        pipeline.load_dataset(image_or_folder)
        pipeline.set_crop_config((512, 512))
        pipeline.process()
        # cropped results are in pipeline.results, keyed by detection index
    """

    def __init__(self, out_folder='output'):
        # Device selection (CUDA when available); half precision only off-CPU.
        self.device = select_device()
        self.half = self.device.type != 'cpu'
        # Maps detection index -> value returned by crop().
        self.results = {}
        # BUGFIX: process() reads self.out_folder, but the original class never
        # assigned it anywhere, so the first confident detection raised
        # AttributeError. Give it a default here (overridable per call via
        # set_crop_config).
        self.out_folder = out_folder

    def load_dataset(self, source):
        """Load an image path (or directory of images) to run detection on."""
        self.source = source
        self.dataset = LoadImages(source)
        print(f'Successfully load {source}')

    def load_model(self, model):
        """Load YOLOv5 weights from *model* onto the selected device."""
        self.model = attempt_load(model, map_location=self.device)
        if self.half:
            self.model.half()
        print(f'Successfully load model weights from {model}')

    def set_crop_config(self, target_size, mode=0, face_ratio=3, threshold=1.5, out_folder=None):
        """Store cropping parameters for the next process() run.

        Args:
            target_size: (width, height) of the crop output.
            mode: 0 Auto; 1 No Scale; 2 Full Screen; 3 Fixed Face Ratio
                  (mirrors the choices offered in the Gradio dropdown).
            face_ratio: face-to-image ratio used by the fixed-ratio mode.
            threshold: scale threshold forwarded to crop().
            out_folder: optional output-directory override
                        (new, backward-compatible parameter).
        """
        self.target_size = target_size
        self.mode = mode
        self.face_ratio = face_ratio
        self.threshold = threshold
        if out_folder is not None:
            self.out_folder = out_folder

    def info(self):
        """Print every non-callable, non-dunder attribute (debug helper)."""
        for attribute in dir(self):
            if not attribute.startswith('__') and not callable(getattr(self, attribute)):
                value = getattr(self, attribute)
                print(attribute, " = ", value)

    def process(self):
        """Run detection over the loaded dataset and crop every confident face.

        Results are stored in self.results keyed by detection index (0-based).
        """
        # Reset results so indices from a previous run cannot leak through.
        self.results = {}
        # Make sure the destination directory exists before writing crops.
        os.makedirs(str(Path(self.out_folder)), exist_ok=True)
        for path, img, im0s, vid_cap in self.dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            pred = self.model(img, augment=False)[0]

            # Apply NMS
            pred = non_max_suppression(pred)

            # Process detections
            for i, det in enumerate(pred):  # detections per image
                p, s, im0 = path, '', im0s

                # NOTE(review): assumes self.source is a directory; for a
                # single-file source this joins file/filename — confirm callers.
                in_path = str(Path(self.source) / Path(p).name)

                s += '%gx%g ' % img.shape[2:]  # print string
                gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh

                if det is not None and len(det):
                    # Rescale boxes from img_size to im0 size
                    det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                    # Crop each detection above the confidence cutoff.
                    ind = 0
                    for *xyxy, conf, cls in det:
                        if conf > 0.6:
                            out_path = os.path.join(str(Path(self.out_folder)), Path(p).name.replace('.', '_' + str(ind) + '.'))

                            # Normalized center (x, y) and box size (w, h).
                            x, y, w, h = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
                            # NOTE: 'shreshold' is the (misspelled) keyword crop() actually accepts.
                            self.results[ind] = crop(in_path, (x, y), out_path, mode=self.mode, size=self.target_size, box=(w, h), face_ratio=self.face_ratio, shreshold=self.threshold)

                            ind += 1
90
def run(img, mode, width, height):
    """Gradio callback: crop the first detected face out of *img*.

    Args:
        img: uploaded image (forwarded to the pipeline's dataset loader).
        mode: crop mode index (0-3, see the dropdown in the UI).
        width: target crop width in pixels.
        height: target crop height in pixels.

    Returns:
        The first cropped result produced by the pipeline.

    Raises:
        KeyError: when no face clears the confidence threshold.
    """
    face_crop_pipeline.load_dataset(img)
    face_crop_pipeline.set_crop_config(mode=mode, target_size=(width, height))
    # BUGFIX: the original read `face_crop_pipeline.process` without calling
    # it — a no-op attribute access — so detection never ran and the function
    # returned stale or missing results.
    face_crop_pipeline.process()
    return face_crop_pipeline.results[0]
95
+
96
if __name__ == '__main__':
    # Download the YOLOv5 anime-face weights once and wire up the pipeline
    # used by the run() callback.
    model_path = huggingface_hub.hf_hub_download("Carzit/yolo5x_anime", "yolo5x_anime.pt")
    face_crop_pipeline = FaceCrop()
    face_crop_pipeline.load_model(model_path)

    app = gr.Blocks()
    with app:
        # BUGFIX: the original markdown was copy-pasted from the SkyTNT
        # anime-segmentation demo (wrong visitor badge, wrong repo link,
        # wrong description). Describe this app and its actual model instead.
        gr.Markdown("# Anime Face Crop\n\n"
                    "Detect and crop anime faces with "
                    "[Carzit/yolo5x_anime](https://huggingface.co/Carzit/yolo5x_anime).")
        with gr.Row():
            input_img = gr.Image(label="input image")
            output_img = gr.Image(label="result", image_mode="RGB")
            crop_mode = gr.Dropdown([0, 1, 2, 3], label="Crop Mode", info="0:Auto; 1:No Scale; 2:Full Screen; 3:Fixed Face Ratio")
            tgt_width = gr.Slider(10, 2048, value=512, label="Width")
            tgt_height = gr.Slider(10, 2048, value=512, label="Height")

        run_btn = gr.Button(variant="primary")

        run_btn.click(run, [input_img, crop_mode, tgt_width, tgt_height], [output_img])
    app.launch()
118
+
119
+
120
+