import os
import tempfile
from typing import Optional, Tuple
from datetime import datetime

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from pose_detector import PoseDetector
from skeleton_generator import SkeletonGenerator
from animation_exporter import AnimationExporter
from utils import process_video, process_image, process_gif
from database import get_db, ProcessedFile, PoseData, AnimationData


def init_page():
    """Configure the Streamlit page and sidebar controls.

    Returns:
        float: the detection-confidence threshold chosen in the sidebar.
        NOTE(review): this value is currently not wired into PoseDetector —
        confirm whether the detector should receive it.
    """
    st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")
    st.markdown("""
    """, unsafe_allow_html=True)

    st.sidebar.title("Settings")
    confidence_threshold = st.sidebar.slider(
        "Detection Confidence",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )
    # Stored in session state under 'export_format'; consumed by
    # provide_download_button when naming the exported file.
    st.sidebar.selectbox(
        "Export Format",
        options=['uasset', 'fbx', 'bvh'],
        key='export_format',
    )

    # Persist the checkbox across reruns. .get() collapses the original
    # duplicated if/else ("key missing" vs "key present") into one call
    # with identical behavior.
    st.session_state.manual_correction = st.sidebar.checkbox(
        "Enable Manual Corrections",
        value=st.session_state.get("manual_correction", False),
    )
    if st.session_state.manual_correction:
        # Interactive correction is temporarily disabled (user-facing text
        # intentionally left in Russian).
        st.sidebar.info("Интерактивная корректировка временно отключена.")

    st.title("Pose Detection & Animation Generator")
    return confidence_threshold


def init_components() -> Tuple[PoseDetector, SkeletonGenerator, AnimationExporter]:
    """Instantiate the detection/skeleton/export pipeline components."""
    return PoseDetector(), SkeletonGenerator(), AnimationExporter()


def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
    """Create a ProcessedFile DB record for the upload in 'processing' state.

    GIFs are recorded as 'video' because they are processed frame-by-frame.
    `components` is unused here but kept for call-site compatibility.

    Returns:
        The freshly committed ProcessedFile row (refreshed so .id is set).
    """
    processed_file = ProcessedFile(
        filename=uploaded_file.name,
        file_type='video' if uploaded_file.type == 'image/gif' else file_type,
        processing_status="processing",
    )
    db_session.add(processed_file)
    db_session.commit()
    db_session.refresh(processed_file)
    return processed_file


def main():
    """Entry point: render the UI and route the upload to the image or video path."""
    init_page()
    components = init_components()

    uploaded_file = st.file_uploader(
        "Choose an image or video file (max 50MB)",
        type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif'],
    )
    if not uploaded_file:
        st.warning("Please upload a file to begin.")
        return

    db = next(get_db())
    try:
        file_type = uploaded_file.type.split('/')[0]
        is_gif = uploaded_file.type == 'image/gif'
        processed_file = handle_upload(file_type, uploaded_file, components, db)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Original")
        with col2:
            st.subheader("Processed")

        try:
            if file_type == 'image' and not is_gif:
                process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
            else:
                # GIFs and all videos go through the frame-by-frame path.
                process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
        except Exception as e:
            # Mark the DB record failed before surfacing the error to the user.
            st.error(f"Processing error: {str(e)}")
            processed_file.processing_status = "failed"
            db.commit()
            return

        processed_file.processing_status = "completed"
        db.commit()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
    finally:
        db.close()


def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
    """Process a still image, caching decode/detection results in session state.

    Interactive correction is currently disabled.
    """
    pose_detector, skeleton_generator, animation_exporter = components

    # Decode once per session: on Streamlit reruns the upload buffer is
    # already consumed, so we cache the decoded image.
    if "uploaded_image" not in st.session_state:
        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, 1)
        st.session_state.uploaded_image = image
    else:
        image = st.session_state.uploaded_image

    with col1:
        st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)

    # Detection is expensive; run it once and cache both outputs.
    if "original_skeleton_data" not in st.session_state:
        processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
        st.session_state.original_skeleton_data = skeleton_data
        st.session_state.processed_image = processed_image
    else:
        skeleton_data = st.session_state.original_skeleton_data
        processed_image = st.session_state.processed_image

    save_pose_data(db, processed_file.id, skeleton_data)
    animation_data_binary = animation_exporter.export_pose(skeleton_data)
    save_animation_data(db, processed_file.id, skeleton_data)

    with col2:
        processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
        st.image(processed_rgb, use_column_width=True)
        # Fix: the exported animation was computed but never offered for
        # download on the image path (only the video path had the button).
        provide_download_button(animation_data_binary)

    if st.session_state.get('manual_correction', False):
        st.info("Интерактивная корректировка временно отключена.")


def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """Process a video or GIF upload frame-by-frame and offer the export."""
    pose_detector, skeleton_generator, animation_exporter = components

    progress_bar = st.progress(0)

    # delete=False so the file can be re-opened by *path* inside
    # process_video/process_gif — re-opening an open NamedTemporaryFile by
    # name fails on Windows. We remove it ourselves in the finally block.
    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        tfile.write(uploaded_file.read())
        tfile.seek(0)
        with col1:
            st.video(tfile.read())
        video_path = tfile.name
        tfile.close()

        if is_gif:
            processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
        else:
            processed_video_path, animation_frames = process_video(video_path, pose_detector, skeleton_generator)
    finally:
        tfile.close()
        os.unlink(tfile.name)

    # The original bar was created but never advanced; mark completion.
    progress_bar.progress(100)

    if not animation_frames:
        raise ValueError("No poses detected in the video/gif")

    save_video_data(db, processed_file.id, animation_frames)
    animation_data_binary = animation_exporter.export_animation(animation_frames)

    with col2:
        if processed_video_path:
            with open(processed_video_path, "rb") as f:
                st.video(f.read())
        provide_download_button(animation_data_binary)


def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Persist a single-frame pose record for the given file."""
    pose_data = PoseData(file_id=file_id, landmarks=skeleton_data)
    db.add(pose_data)
    db.commit()


def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Persist the skeleton data used to build the exported animation."""
    animation_data = AnimationData(file_id=file_id, skeleton_data=skeleton_data)
    db.add(animation_data)
    db.commit()


def save_video_data(db, file_id: int, animation_frames: list):
    """Persist one PoseData row per frame, committed in a single batch."""
    for frame_num, frame_data in enumerate(animation_frames):
        pose_data = PoseData(file_id=file_id, frame_number=frame_num, landmarks=frame_data)
        db.add(pose_data)
    db.commit()


def provide_download_button(animation_data_binary):
    """Offer the exported animation for download.

    Fix: the file extension now follows the sidebar's export-format
    selection (previously hard-coded to '.uasset' regardless of choice).
    """
    export_format = st.session_state.get('export_format', 'uasset')
    st.download_button(
        label="Download Animation Data",
        data=animation_data_binary,
        file_name=f"animation.{export_format}",
        mime="application/octet-stream",
    )


def save_corrected_pose(db, file_id: int, joints: dict):
    """Persist a manually corrected pose (currently unused while correction is disabled)."""
    pose_data = PoseData(
        file_id=file_id,
        landmarks=joints,
        corrected_landmarks=joints,
        is_corrected=True,
    )
    db.add(pose_data)
    db.commit()


def show_instructions():
    """Render the collapsible usage instructions panel."""
    with st.expander("Instructions"):
        st.markdown("""
        1. Upload an image or video using the file uploader.
        2. Wait for processing to complete.
        3. The **Detection Confidence** slider in the sidebar sets the minimum confidence threshold for detecting a pose. A higher value (e.g., 0.7) means that only detections with high certainty are considered, which may improve accuracy.
        4. The processed image/video will be displayed in the "Processed" panel.
        5. Download animation data. Supported formats:
        - Images: JPG, PNG
        - Videos: MP4, GIF
        """)


if __name__ == "__main__":
    main()
    show_instructions()