"""Streamlit demo: drive a TPSMM source face with faces detected in a webcam stream.

Faces are detected by RetinaFace in each WebRTC video frame. Each face crop is
passed to TPSMM together with a source image (a bundled asset or a user upload).
The generated image is pasted into the top-left corner of the frame.
"""
import logging
import sys

import av
import cv2
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_webrtc import WebRtcMode, webrtc_streamer

# The project packages live in sibling directories that are not installed.
# Make them importable before the imports below.
sys.path.insert(1, "./retinaface")
sys.path.insert(1, "./TPSMM/pkgs")

from tpsmm import TPSMM  # noqa: E402
from detect import Detect  # noqa: E402
from turn import get_ice_servers  # noqa: E402

logger = logging.getLogger(__name__)

# Bundled source images offered in the selectbox under the keys "0".."3".
ASSET_PATHS = [f"assets/{i}.jpg" for i in range(4)]

# Side length (pixels) of the top-left region the generated image is pasted
# into. NOTE(review): assumes gen_image returns a PREVIEW_SIZE x PREVIEW_SIZE
# BGR image -- confirm against TPSMM.gen_image.
PREVIEW_SIZE = 256

# BGR colours for the five detected landmarks, in detection order.
LANDMARK_COLORS = [
    (0, 0, 255),
    (0, 255, 255),
    (255, 0, 255),
    (0, 255, 0),
    (255, 0, 0),
]


def parse_roi_box_from_bbox(bbox, shape):
    """Return a square ROI ``[left, top, right, bottom]`` around a face box.

    The square is centred on the detection box, shifted down by 14% of the
    box's mean side length. Its half-side equals that mean side length, shrunk
    as needed so the square stays inside the image.

    Args:
        bbox: Sequence whose first four items are ``left, top, right, bottom``.
        shape: Image shape; only ``shape[:2]`` (height, width) is used.

    Returns:
        A list of four floats. The side length is truncated to an int and
        clamped to be non-negative. A centre that falls outside the image
        therefore yields an empty (zero-size) box rather than an inverted one.
    """
    img_h, img_w = shape[:2]
    left, top, right, bottom = bbox[:4]
    old_size = (right - left + bottom - top) / 2
    center_x = right - (right - left) / 2.0
    center_y = bottom - (bottom - top) / 2.0 + old_size * 0.14
    # The half-side is limited by the distance from the centre to each image edge.
    half_side = min(old_size, center_x, img_w - center_x, center_y, img_h - center_y)
    size = max(0, int(half_side * 2.0))
    roi_left = center_x - size / 2
    roi_top = center_y - size / 2
    return [roi_left, roi_top, roi_left + size, roi_top + size]


def _crop_roi(img, det):
    """Return ``(roi, crop)``: the integer ROI for detection ``det`` and a copy of that region of ``img``.

    The crop may be empty (``crop.size == 0``) when the ROI is degenerate.
    """
    roi = [int(v) for v in parse_roi_box_from_bbox(det[:4], img.shape)]
    return roi, img[roi[1]:roi[3], roi[0]:roi[2]].copy()


def _session_cached(key, factory):
    """Return ``st.session_state[key]``, creating it with ``factory()`` on first use in this session."""
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


detector = _session_cached(
    "retinaface",
    lambda: Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864)),
)
generator = _session_cached("tpsmm", TPSMM)


def _read_image(path):
    """Read ``path`` with OpenCV (BGR), raising instead of returning None on failure."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return img


@st.cache_resource  # type: ignore
def get_images():
    """Return ``{"0": source, "1": source, ...}`` for the bundled assets.

    Each value is whatever ``generator.process_source`` returns. The callback
    unpacks it as ``(source_tensor, kp_source)``.
    """
    return {
        str(index): generator.process_source(_read_image(path))
        for index, path in enumerate(ASSET_PATHS)
    }


images = get_images()
user_option = st.selectbox("Choose an item", list(images.keys()))
uploaded_file = st.file_uploader("Or upload your file here...", type=['png', 'jpeg', 'jpg'])


@st.cache_resource
def process_file(uploaded_file):
    """Turn the first usable face in an uploaded image into a TPSMM source.

    Returns:
        The result of ``generator.process_source`` for the first detected face
        whose ROI crop is non-empty, or None if there is no such face.
    """
    # Convert to RGB first. PNG uploads are often RGBA, palette or greyscale,
    # and COLOR_RGB2BGR rejects anything that is not 3-channel.
    rgb = np.array(Image.open(uploaded_file).convert("RGB"))
    img = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    for det in detector(img):
        _, face_img = _crop_roi(img, det)
        if face_img.size:
            return generator.process_source(face_img)
    return None


# Processed source for the uploaded image; None means "use the selectbox item".
uploaded_source = None
if uploaded_file is not None:
    uploaded_source = process_file(uploaded_file)
    if uploaded_source is None:
        st.warning("No face found in the uploaded image; using the selected item instead.")


def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate detected faces in a frame and paste in the TPSMM output.

    Each detection row is indexed as ``[x1, y1, x2, y2, <score?>, 10 landmark
    coords]``; index 4 is presumably the detector confidence.
    If several faces are found, the preview shows the output for the last one.
    Failures are logged and the frame is returned with whatever annotation was
    drawn before the error, so one bad frame does not stop the stream.
    """
    img = frame.to_ndarray(format="bgr24")
    try:
        if uploaded_source is None:
            source_tensor, kp_source = images[user_option]
        else:
            source_tensor, kp_source = uploaded_source

        dets = [det.astype(np.int32) for det in detector(img)]
        # Take every crop from the clean frame before drawing any overlay.
        # Otherwise rectangles and landmarks would be fed to the generator.
        crops = [_crop_roi(img, det) for det in dets]

        output = None
        for det, (roi, face_img) in zip(dets, crops):
            if face_img.size:
                output = generator.gen_image(face_img, source_tensor, kp_source)
            cv2.rectangle(img, (int(det[0]), int(det[1])), (int(det[2]), int(det[3])), (0, 0, 255), 2)
            cv2.rectangle(img, (roi[0], roi[1]), (roi[2], roi[3]), (255, 0, 0), 2)
            for (x, y), color in zip(det[5:15].reshape((5, 2)), LANDMARK_COLORS):
                cv2.circle(img, (int(x), int(y)), 1, color, 2)

        if output is not None:
            img[:PREVIEW_SIZE, :PREVIEW_SIZE] = output
    except Exception:
        # Best effort: keep the stream alive, but keep the traceback.
        logger.exception("Face reenactment failed for this frame")
    return av.VideoFrame.from_ndarray(img, format="bgr24")


webrtc_streamer(
    key="sample",
    rtc_configuration={"iceServers": get_ice_servers()},
    video_frame_callback=callback,
    media_stream_constraints={"video": True, "audio": False},
)