# VASR_ / src/app.py
# Last commit: "Update src/app.py" by sophiemaw (b3b6a59, verified)
import streamlit as st
import os
import tempfile
from PIL import Image
# VASR pipeline imports
from generate import generate_face_image
from faceswap import face_swap_multiple_identities
from video_utils import extract_frames, recombine_frames
from face_detection import detect_faces_insightface_from_frames, extract_faces_from_frames_folder
from grouping import group_faces_by_identity_facenet_fixed, split_csv_by_identity
from load_models import warm_up_models
from main import fix_faces
from image_utils import add_watermark_to_frames
# --- Streamlit page setup ---------------------------------------------------
# Page configuration is the first Streamlit call made by this script.
st.set_page_config(page_title="VASR", layout="centered")
st.title("🧬 VASR - Video Anonymisation")

# --- Per-session scratch space ----------------------------------------------
# Streamlit re-executes this script on every interaction. The scratch
# directory (created under /tmp, which is writable on Hugging Face) is
# therefore remembered in session_state, so all reruns of one session share it.
if "hf_tmp" not in st.session_state:
    st.session_state["hf_tmp"] = tempfile.mkdtemp(dir="/tmp")
HF_TMP = st.session_state["hf_tmp"]


def _scratch_path(name):
    """Return ``name`` joined onto this session's scratch directory."""
    return os.path.join(HF_TMP, name)


# --- Pipeline file/folder locations inside the scratch directory ------------
input_video_path = _scratch_path("input.mp4")
frames_folder = _scratch_path("frames")
csv_path = _scratch_path("frames_detections.csv")
output_faces = _scratch_path("output_faces")
grouped_faces_dir = _scratch_path("grouped_faces_facenet")
identity_csv_dir = _scratch_path("identity_csvs")
output_frames = _scratch_path("output_frames")
param_log_path = _scratch_path("grouping_params.txt")
output_video_path = _scratch_path("anonymised_output.mp4")
# UI elements
uploaded_video = st.file_uploader("πŸ“Ή Upload video file", type=["mp4", "mov"])
num_ids = st.number_input("πŸ‘₯ How many people are in the video?", min_value=1, value=2)
warm_up_models()

# Persist the upload to a fixed path (input_video_path) so the pipeline
# helpers can read it from disk. This runs on every rerun while a file is
# selected in the uploader.
if uploaded_video:
    # getvalue() returns the whole uploaded buffer regardless of the current
    # stream position. read() would return b"" if the buffer had already
    # been consumed.
    video_bytes = uploaded_video.getvalue()
    try:
        with open(input_video_path, "wb") as f:
            f.write(video_bytes)
    except OSError as e:
        # Opening or writing the destination file is the only step that can fail here.
        st.error(f"❌ Failed to write video: {e}")
    else:
        st.success("βœ… File saved")
def _reset_preprocessing_outputs():
    """Discard artefacts and session state left by an earlier preprocessing run.

    The pipeline writes into fixed paths inside this session's HF_TMP folder.
    Re-running it (new upload, different person count) could otherwise mix the
    previous run's identity folders and per-identity CSVs with the new ones.
    A failed run would also leave the old identities in ``st.session_state``,
    and "Anonymise" would then apply them to the new frames.

    None of these paths exist in a freshly created HF_TMP, so the pipeline
    helpers presumably create them as needed. Deleting them restores that
    first-run state.
    """
    import shutil  # local import: only needed by this helper

    for folder in (frames_folder, output_faces, grouped_faces_dir,
                   identity_csv_dir, output_frames):
        shutil.rmtree(folder, ignore_errors=True)
    try:
        os.remove(csv_path)
    except FileNotFoundError:
        pass
    for key in ("generated_faces", "identity_index", "identity_csv_paths",
                "grouped_faces_dir", "video_ready"):
        if key in st.session_state:
            del st.session_state[key]


def _show_grouping_summary():
    """Render the entries of grouped_faces_dir and the file count of each sub-folder."""
    entries = sorted(os.listdir(grouped_faces_dir))
    st.write("πŸ“ Grouped identities:", entries)
    for subfolder in entries:
        subdir_path = os.path.join(grouped_faces_dir, subfolder)
        # Skip stray files so os.listdir() is only called on directories.
        if os.path.isdir(subdir_path):
            st.write(f" - {subfolder}: {len(os.listdir(subdir_path))} images")


if st.button("πŸš€ Preprocess video"):
    if uploaded_video is None or not os.path.exists(input_video_path):
        # Nothing to preprocess: either no upload, or saving it to disk failed.
        st.warning("⚠️ Please upload a video before preprocessing.")
    else:
        _reset_preprocessing_outputs()
        progress = st.progress(0)

        with st.spinner("🧠 Warming up models..."):
            warm_up_models()
        progress.progress(5)

        with st.spinner("πŸ“½ Extracting frames..."):
            extract_frames(input_video_path, frames_folder)
        progress.progress(10)

        with st.spinner("πŸ•΅οΈ Detecting faces..."):
            detect_faces_insightface_from_frames(
                frames_folder,
                csv_path,
                streamlit_progress=progress,
                progress_range=(10, 70)
            )
        progress.progress(70)

        with st.spinner("βœ‚οΈ Extracting face crops..."):
            extract_faces_from_frames_folder(
                frames_folder,
                csv_path,
                output_faces,
                streamlit_progress=progress,
                progress_range=(70, 75)
            )
        progress.progress(75)

        with st.spinner("πŸ‘₯ Grouping identities..."):
            # Truthy result means grouping found the requested number of identities.
            identified = group_faces_by_identity_facenet_fixed(
                faces_folder=output_faces,
                output_folder=grouped_faces_dir,
                num_identities=num_ids,
                streamlit_progress=progress,
                progress_range=(75, 95),
                param_log_path=param_log_path
            )

        if identified:
            split_csv_by_identity(
                original_csv=csv_path,
                grouped_folder=grouped_faces_dir,
                output_dir=identity_csv_dir
            )
            _show_grouping_summary()
            progress.progress(95)

            # State consumed by the identity review / anonymise section below.
            st.session_state.identity_csv_paths = [
                os.path.join(identity_csv_dir, f)
                for f in sorted(os.listdir(identity_csv_dir))
                if f.endswith(".csv")
            ]
            st.session_state.grouped_faces_dir = grouped_faces_dir
            st.session_state.generated_faces = [None] * num_ids
            st.session_state.identity_index = 0
            st.session_state.video_ready = True
            st.success("βœ… Ready to generate anonymised faces.")
            progress.progress(100)
        else:
            st.error(f"❌ Unable to find {num_ids} distinct people in the video. "
                     f"The video may be too short or too low quality. Please try another.")
# Identity selection loop
def _find_sample_image(grouped_dir, index):
    """Return a sample face image path for the ``index``-th identity, or ``None``.

    Identity folders are the entries of ``grouped_dir`` whose names start with
    ``identity_``, taken in sorted (lexicographic) order. The sample is the
    first ``.jpg``/``.png`` file (case-insensitive, sorted by name) in that
    folder. UI warnings are rendered when no folder matches ``index``.
    """
    # NOTE(review): lexicographic order puts "identity_10" before "identity_2";
    # presumably this matches the sorted identity CSV order used in preprocessing — verify.
    identity_folders = []
    if os.path.isdir(grouped_dir):
        identity_folders = sorted(
            name for name in os.listdir(grouped_dir) if name.startswith("identity_")
        )
    if index >= len(identity_folders):
        st.warning("❌ No identity folder found for this index.")
        return None

    sample_dir = os.path.join(grouped_dir, identity_folders[index])
    if not os.path.isdir(sample_dir):
        st.warning(f"❌ Folder not found: {sample_dir}")
        return None

    st.write(f"πŸ“‚ Using sample directory: {sample_dir}")
    image_files = sorted(
        name for name in os.listdir(sample_dir)
        if name.lower().endswith(('.jpg', '.png'))
    )
    return os.path.join(sample_dir, image_files[0]) if image_files else None


def _render_identity_review(current_index, total_faces):
    """Show one identity's original and generated face plus the generate/confirm buttons."""
    st.header(f"🧬 Identity {current_index + 1}/{total_faces}")

    # Show original sample image
    sample_path = _find_sample_image(st.session_state.grouped_faces_dir, current_index)
    if sample_path and os.path.exists(sample_path):
        st.image(sample_path, caption="🎭 Original Face", width=256)
    else:
        st.warning("⚠️ No sample image found for this identity.")

    # Show current generated face if exists
    if st.session_state.generated_faces[current_index] is not None:
        st.image(st.session_state.generated_faces[current_index], caption="🧬 Generated Face", width=256)

    # Generate new face button (key is unique per identity)
    if st.button("πŸ” Generate New Identity", key=f"generate_{current_index}"):
        with st.spinner("🧠 Generating new face..."):
            new_face = generate_face_image()
        st.session_state.generated_faces[current_index] = new_face
        st.rerun()

    # Confirm and move to next identity
    if st.button("βœ… Confirm & Next Identity", key=f"next_{current_index}"):
        if st.session_state.generated_faces[current_index] is None:
            st.warning("⚠️ Please generate a face before continuing.")
        else:
            st.session_state.identity_index += 1
            # st.experimental_rerun() is deprecated/removed; st.rerun() is its replacement.
            st.rerun()


def _run_anonymisation(total_faces):
    """Swap the chosen faces into every frame, watermark them, and rebuild the video."""
    swap_progress = st.progress(0)
    with st.spinner("πŸ“‘ Fixing CSVs..."):
        identity_csvs = fix_faces(total_faces, st.session_state.identity_csv_paths)
    swap_progress.progress(1)

    with st.spinner("πŸ€– Swapping faces..."):
        face_swap_multiple_identities(
            frame_folder=frames_folder,
            output_folder=output_frames,
            generated_images=st.session_state.generated_faces,
            identity_csv_paths=identity_csvs,
            streamlit_progress=swap_progress,
            progress_range=(1, 96)
        )
        # Watermark the swapped frames in place (same input and output folder).
        add_watermark_to_frames(output_frames, output_frames)
    swap_progress.progress(96)

    with st.spinner("🎞 Recombining video..."):
        recombine_frames(
            input_video_path,
            frames_folder=output_frames,
            output_video=output_video_path
        )
    swap_progress.progress(100)
    st.success("πŸŽ‰ Anonymised video complete!")
    st.video(output_video_path, format="video/mp4")


# Only shown once preprocessing has stored the identity state.
if "generated_faces" in st.session_state:
    current_index = st.session_state.identity_index
    total_faces = len(st.session_state.generated_faces)
    st.text(f"πŸ“ Current identity index: {current_index}")  # Debug

    if current_index < total_faces:
        _render_identity_review(current_index, total_faces)
    else:
        # If all identities confirmed, show anonymise option
        st.success("βœ… All identities reviewed.")
        if st.button("πŸš€ Anonymise Video"):
            _run_anonymisation(total_faces)