# Provenance (from the hosting page): author thaint2901, commit f3261a0 "init".
import av
import sys
import numpy as np
import cv2
import streamlit as st
from PIL import Image
from streamlit_webrtc import WebRtcMode, webrtc_streamer
sys.path.insert(1, "./retinaface")
sys.path.insert(1, "./TPSMM/pkgs")
from tpsmm import TPSMM
from detect import Detect
from turn import get_ice_servers
def parse_roi_box_from_bbox(bbox, shape):
    """Turn a face detection box into a square crop box.

    The square is centred on the detection box, shifted down by 14% of the
    box's mean side length. Its side is twice the smallest of:
    - the mean side length;
    - the distances from the shifted centre to each image border.
    The side is truncated to ``int``. When the centre lies inside the image,
    this keeps the square within the image bounds.

    Args:
        bbox: sequence whose first four items are ``left, top, right, bottom``.
        shape: image shape; only ``shape[:2]`` (height, width) is used.

    Returns:
        list ``[x_min, y_min, x_max, y_max]``. Coordinates may be floats;
        callers truncate them with ``int`` before slicing.
    """
    img_height, img_width = shape[:2]
    x0, y0, x1, y1 = bbox[:4]
    mean_side = (x1 - x0 + y1 - y0) / 2
    cx = x1 - (x1 - x0) / 2.0
    # Push the centre down a little so the crop covers chin/mouth more than forehead.
    cy = y1 - (y1 - y0) / 2.0 + mean_side * 0.14
    # Half-side limited by the box size and by every image border.
    half_side = min(mean_side, cx, img_width - cx, cy, img_height - cy)
    side = int(half_side * 2.0)
    x_min = cx - side / 2
    y_min = cy - side / 2
    return [x_min, y_min, x_min + side, y_min + side]
def _session_cached(key, factory):
    """Return ``st.session_state[key]``, building it with ``factory()`` on first use.

    Stores heavyweight models in the session state so Streamlit reruns of
    the same session reuse them instead of reloading.
    """
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


detector = _session_cached(
    "retinaface",
    lambda: Detect("./retinaface/weights/mobilenet0.25_epoch_842.pth", net_inshape=(486, 864)),
)
generator = _session_cached("tpsmm", TPSMM)
@st.cache_resource  # type: ignore
def get_images():
    """Load the bundled source portraits and pre-process them with the generator.

    Decorated with ``st.cache_resource``, so the images are read and encoded
    once and then reused instead of on every rerun.

    Returns:
        dict mapping the item label (``"0"`` .. ``"3"``) to the value of
        ``generator.process_source`` for ``assets/<label>.jpg``. The value is
        unpacked by the frame callback as ``(source_tensor, kp_source)``.

    Raises:
        FileNotFoundError: if one of the bundled images cannot be read.
    """
    processed = {}
    for idx in range(4):
        path = f"assets/{idx}.jpg"
        src_img = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising;
        # fail loudly here instead of deep inside process_source.
        if src_img is None:
            raise FileNotFoundError(f"Could not read source image: {path}")
        processed[str(idx)] = generator.process_source(src_img)
    return processed
images = get_images()
# Bundled portrait picker; an uploaded face (if any) overrides it in the frame callback.
user_option = st.selectbox("Choose an item", list(images))
uploaded_file = st.file_uploader(
    "Or upload your file here...",
    type=["png", "jpeg", "jpg"],
)
@st.cache_resource
def process_file(uploaded_file):
    """Detect a face in an uploaded image and pre-process it as the animation source.

    Args:
        uploaded_file: the object returned by ``st.file_uploader``.

    Returns:
        ``generator.process_source`` applied to the square crop around the
        first detected face that yields a non-empty crop. Returns ``None``
        when no such face exists; the frame callback then falls back to the
        bundled image chosen in the select box.
    """
    # PNG uploads may be RGBA, palette or greyscale; normalise to 3-channel
    # RGB so the RGB->BGR conversion below cannot fail on the channel count.
    img = np.array(Image.open(uploaded_file).convert("RGB"))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    for det in detector(img):
        roi = [int(v) for v in parse_roi_box_from_bbox(det[:4], img.shape)]
        face_img = img[roi[1]:roi[3], roi[0]:roi[2]].copy()
        # A degenerate ROI (centre on or beyond the image border) slices to an
        # empty array; skip it rather than hand it to the generator.
        if face_img.size == 0:
            continue
        return generator.process_source(face_img)
    return None
# From here on `uploaded_file` holds the pre-processed source returned by
# process_file (or None), which is what the frame callback reads.
uploaded_file = process_file(uploaded_file) if uploaded_file is not None else None
# BGR colours for landmark indices 0-4, as drawn on every detected face.
_LANDMARK_COLORS = ((0, 0, 255), (0, 255, 255), (255, 0, 255), (0, 255, 0), (255, 0, 0))


def _selected_source():
    """Return the ``(source_tensor, kp_source)`` pair to animate.

    An uploaded face (already run through ``process_file``) takes precedence
    over the bundled image chosen in the select box.
    """
    if uploaded_file is None:
        return images[user_option]
    return uploaded_file


def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Drive the selected source face with each webcam face and annotate the frame.

    For every detected face the frame gets:
    - the raw detection box (red);
    - the square crop region (blue);
    - five landmark dots.
    The crop, taken from an unannotated copy of the frame, is the driving
    image for ``generator.gen_image``. The output for the last detected face
    is pasted into the top-left corner.

    Any exception is printed and the (possibly partially annotated) frame is
    still returned, so one bad frame does not stop the stream.
    """
    img = frame.to_ndarray(format="bgr24")
    try:
        dets = detector(img)
        # Crop from an untouched copy: the boxes drawn below would otherwise
        # end up inside the driving image fed to the generator.
        clean = img.copy()
        source_tensor, kp_source = _selected_source()
        output = None
        for det in dets:
            box = det.astype(np.int32)
            # .tolist() yields plain Python ints, which every cv2 build accepts as points.
            x0, y0, x1, y1 = box[:4].tolist()
            cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 2)
            roi = [int(v) for v in parse_roi_box_from_bbox(box[:4], img.shape)]
            cv2.rectangle(img, (roi[0], roi[1]), (roi[2], roi[3]), (255, 0, 0), 2)
            face_img = clean[roi[1]:roi[3], roi[0]:roi[2]].copy()
            output = generator.gen_image(face_img, source_tensor, kp_source)
            landmarks = box[5:15].reshape((5, 2)).tolist()
            for point, color in zip(landmarks, _LANDMARK_COLORS):
                cv2.circle(img, tuple(point), 1, color, 2)
        if output is not None:
            # NOTE(review): assumes gen_image returns a 256x256 BGR image and the
            # frame is at least 256x256 — confirm against TPSMM.gen_image.
            img[:256, :256] = output
    except Exception as e:
        print(e)
    return av.VideoFrame.from_ndarray(img, format="bgr24")
# Start the browser webcam stream (video only); each frame goes through `callback`.
webrtc_streamer(
    key="sample",
    video_frame_callback=callback,
    rtc_configuration={"iceServers": get_ice_servers()},
    media_stream_constraints={"video": True, "audio": False},
)