Spaces:
Runtime error
Runtime error
| import av | |
| import sys | |
| import numpy as np | |
| import cv2 | |
| import streamlit as st | |
| from PIL import Image | |
| from streamlit_webrtc import WebRtcMode, webrtc_streamer | |
| sys.path.insert(1, "./retinaface") | |
| sys.path.insert(1, "./TPSMM/pkgs") | |
| from tpsmm import TPSMM | |
| from detect import Detect | |
| from turn import get_ice_servers | |
def parse_roi_box_from_bbox(bbox, shape):
    """Turn a face bounding box into a square crop window for the image.

    The window is centred horizontally on the box and vertically on the box
    centre shifted down by 14% of the box's mean side length. Its side is at
    most twice that mean side length and is reduced so that the window does
    not reach further than the distance from the centre to any image edge.

    Args:
        bbox: Sequence whose first four items are ``left, top, right, bottom``.
        shape: Image shape; only the leading ``(height, width)`` pair is read.

    Returns:
        A four-item list ``[x0, y0, x1, y1]``. The side length is truncated
        to an int, but the corners themselves may be fractional.
    """
    frame_h, frame_w = shape[:2]
    x_min, y_min, x_max, y_max = bbox[:4]

    # Mean of the box width and height.
    mean_side = (x_max - x_min + y_max - y_min) / 2

    mid_x = x_max - (x_max - x_min) / 2.0
    mid_y = y_max - (y_max - y_min) / 2.0 + mean_side * 0.14

    # Half-side candidates: the box-derived limit and the distance from the
    # centre to each of the four image borders.
    half_side = min(
        (mean_side * 2.0) / 2,
        mid_x,
        frame_w - mid_x,
        mid_y,
        frame_h - mid_y,
    )
    side = int(half_side * 2.0)

    x0 = mid_x - side / 2
    y0 = mid_y - side / 2
    return [x0, y0, x0 + side, y0 + side]
# Face detector: build it only when st.session_state does not already hold
# one under this key, then always read it back from there.
cache_key = "retinaface"
if cache_key not in st.session_state:
    st.session_state[cache_key] = Detect(
        "./retinaface/weights/mobilenet0.25_epoch_842.pth",
        net_inshape=(486, 864),
    )
detector = st.session_state[cache_key]

# Image generator: same create-if-missing pattern under its own key.
cache_key = "tpsmm"
if cache_key not in st.session_state:
    st.session_state[cache_key] = TPSMM()
generator = st.session_state[cache_key]
| # type: ignore | |
def get_images():
    """Load the bundled source images and pre-process them for the generator.

    Reads ``assets/0.jpg`` through ``assets/3.jpg`` with OpenCV and passes
    each one to ``generator.process_source``.

    Returns:
        A dict mapping the string keys ``"0"`` .. ``"3"`` to the value
        returned by ``generator.process_source`` for the matching file.

    Raises:
        FileNotFoundError: If an asset cannot be read. ``cv2.imread`` signals
            failure by returning ``None`` rather than raising, so the check
            is explicit here instead of failing later inside the generator.
    """
    processed = {}
    for index in range(4):
        path = "assets/{}.jpg".format(index)
        src_img = cv2.imread(path)
        if src_img is None:
            raise FileNotFoundError(
                "Could not read source image: {}".format(path)
            )
        processed[str(index)] = generator.process_source(src_img)
    return processed
# Pre-processed bundled source faces, keyed by their string index.
images = get_images()

# Source selection: pick one of the bundled faces by key ...
user_option = st.selectbox("Choose an item", list(images))

# ... or provide a custom picture instead.
uploaded_file = st.file_uploader(
    "Or upload your file here...",
    type=['png', 'jpeg', 'jpg'],
)
def process_file(uploaded_file):
    """Extract the first detected face from an uploaded image and pre-process it.

    Args:
        uploaded_file: A file-like object accepted by ``PIL.Image.open``
            (here, the value returned by ``st.file_uploader``).

    Returns:
        The result of ``generator.process_source`` on the cropped face of the
        first detection that yields a non-empty crop, or ``None`` when no
        such detection exists.
    """
    # Normalise to 3-channel RGB first: palette and grayscale files would
    # otherwise decode to 2-D arrays, which the RGB->BGR conversion rejects.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    img = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)

    for det in detector(img):
        roi = [int(v) for v in parse_roi_box_from_bbox(det[:4], img.shape)]
        face_img = img[roi[1]:roi[3], roi[0]:roi[2]].copy()
        if face_img.size == 0:
            # Degenerate crop window; try the next detection instead of
            # handing an empty image to the generator.
            continue
        return generator.process_source(face_img)
    return None
# Swap the raw upload for its processed form; when nothing was uploaded the
# name stays None.
uploaded_file = None if uploaded_file is None else process_file(uploaded_file)
def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one video frame and overlay the generated image on it.

    For each detection the frame gets a red detection rectangle, a blue
    crop-window rectangle and five coloured landmark dots. The face crop is
    passed to ``generator.gen_image`` together with the selected source
    (the processed upload if there is one, otherwise the bundled image
    chosen in the select box). The result for the last detection is written
    into the top-left 256x256 pixels of the frame.

    Args:
        frame: Incoming video frame.

    Returns:
        A new BGR frame. If processing raises, the error is printed and the
        frame is returned with whatever was drawn up to that point.
    """
    img = frame.to_ndarray(format="bgr24")
    try:
        # Crops are taken from an untouched copy so the debug rectangles and
        # dots drawn on ``img`` never end up in the generator's input.
        pristine = img.copy()
        dets = detector(img)

        # The source does not depend on the detection, so pick it once.
        if uploaded_file is None:
            source_tensor, kp_source = images[user_option]
        else:
            source_tensor, kp_source = uploaded_file

        landmark_colors = (
            (0, 0, 255),
            (0, 255, 255),
            (255, 0, 255),
            (0, 255, 0),
            (255, 0, 0),
        )

        output = None
        for det in dets:
            det = det.astype(np.int32)
            # Built-in ints for the drawing calls; NOTE(review): some OpenCV
            # builds are reportedly strict about NumPy scalar types here.
            left, top, right, bottom = (int(v) for v in det[:4])
            cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)

            roi = [
                int(v)
                for v in parse_roi_box_from_bbox(
                    (left, top, right, bottom), img.shape
                )
            ]
            cv2.rectangle(
                img, (roi[0], roi[1]), (roi[2], roi[3]), (255, 0, 0), 2
            )

            face_img = pristine[roi[1]:roi[3], roi[0]:roi[2]].copy()
            if face_img.size:
                output = generator.gen_image(
                    face_img, source_tensor, kp_source
                )

            # Columns 5..14 hold five (x, y) landmark pairs.
            landmarks = det[5:15].reshape((5, 2))
            for point, color in zip(landmarks, landmark_colors):
                cv2.circle(
                    img, (int(point[0]), int(point[1])), 1, color, 2
                )

        if output is not None:
            # Assumes gen_image returns a 256x256 BGR image -- TODO confirm.
            img[:256, :256] = output
    except Exception as e:
        # Deliberately best-effort: a failure on one frame must not stop the
        # stream, so report it and return the frame as it stands.
        print(e)
    return av.VideoFrame.from_ndarray(img, format="bgr24")
# Start the video-only WebRTC stream; every frame is routed through
# ``callback`` and the ICE servers come from ``get_ice_servers``.
webrtc_streamer(
    key="sample",
    video_frame_callback=callback,
    media_stream_constraints={"video": True, "audio": False},
    rtc_configuration={"iceServers": get_ice_servers()},
)