# Spaces:
# Sleeping
# Sleeping
import hashlib
import os
import tempfile
from datetime import datetime
from typing import Optional, Tuple

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from animation_exporter import AnimationExporter
from database import get_db, ProcessedFile, PoseData, AnimationData
from pose_detector import PoseDetector
from skeleton_generator import SkeletonGenerator
from utils import process_video, process_image, process_gif


def init_page():
    """Configure the page, inject custom CSS, and render sidebar settings and title.

    Side effects (all through Streamlit):
      * sets the page config (wide layout, page title);
      * injects the app's custom CSS;
      * renders the sidebar widgets:
        - a detection-confidence slider;
        - an export-format select box, whose value is kept in
          ``st.session_state.export_format``;
        - a manual-corrections checkbox, whose value is kept in
          ``st.session_state.manual_correction``;
      * renders the main page title.

    Returns:
        The value of the "Detection Confidence" slider (0.0-1.0, default 0.5).
    """
    st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")
    st.markdown("""
        <style>
        /* Base styling */
        .stApp {
            max-width: 100% !important;
            padding: 2rem;
            transition: background-color 0.3s;
        }
        [data-theme="dark"] .stApp {
            background-color: #1E1E1E;
            color: #FFFFFF;
        }
        .stButton>button {
            background-color: #4CAF50;
            color: white;
            border-radius: 4px;
            padding: 0.5rem 1rem;
        }
        .stDownloadButton>button {
            background-color: #008CBA;
        }
        .stProgress > div > div > div {
            background-color: #4CAF50;
        }
        .stColumn {
            padding: 0 1rem;
        }
        </style>
    """, unsafe_allow_html=True)

    st.sidebar.title("Settings")
    # NOTE(review): main() in this file ignores the returned threshold, so the
    # slider does not currently influence detection — TODO wire it through.
    confidence_threshold = st.sidebar.slider(
        "Detection Confidence",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1
    )
    # The selection is only kept in session state under "export_format".
    # NOTE(review): nothing in this file reads it; downloads are always named
    # "animation.uasset".
    st.sidebar.selectbox(
        "Export Format",
        options=['uasset', 'fbx', 'bvh'],
        key='export_format'
    )
    # A keyed widget keeps its value in st.session_state.manual_correction
    # across reruns on its own. Feeding the stored value back in through
    # `value=` changes the widget identity whenever the value flips.
    manual_correction = st.sidebar.checkbox("Enable Manual Corrections", key="manual_correction")
    if manual_correction:
        st.sidebar.info("Интерактивная корректировка временно отключена.")
    st.title("Pose Detection & Animation Generator")
    return confidence_threshold


def init_components() -> Tuple[PoseDetector, SkeletonGenerator, AnimationExporter]:
    """Build the processing pipeline components with their default constructors.

    Returns:
        A ``(pose_detector, skeleton_generator, animation_exporter)`` triple.
        New instances are created on every call.
    """
    detector = PoseDetector()
    generator = SkeletonGenerator()
    exporter = AnimationExporter()
    return detector, generator, exporter


def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
    """Create and persist a ``ProcessedFile`` record for a new upload.

    Args:
        file_type: Top-level MIME category of the upload (e.g. ``"image"``).
        uploaded_file: The uploaded file; its ``name`` and ``type`` are read.
        components: Not used here; kept so existing call sites keep working.
        db_session: Database session the record is added to and committed on.

    Returns:
        The committed, refreshed ``ProcessedFile`` with status ``"processing"``.
        GIF uploads (``image/gif``) are recorded with file type ``"video"``,
        presumably because ``main()`` routes them through the video pipeline.
    """
    is_gif = uploaded_file.type == 'image/gif'
    record = ProcessedFile(
        filename=uploaded_file.name,
        file_type='video' if is_gif else file_type,
        processing_status="processing",
    )
    db_session.add(record)
    db_session.commit()
    # Reload the committed row so database-assigned fields (e.g. the ``id``
    # callers use for foreign keys) are populated.
    db_session.refresh(record)
    return record


def main():
    """Run the app: take an upload, process it, and record the outcome.

    Flow:
      1. Render the page/sidebar and build the processing components.
      2. Wait for an upload; reject files above the advertised 50MB limit.
      3. Record the upload as a ``ProcessedFile`` (status ``"processing"``).
      4. Dispatch to the image pipeline (non-GIF images) or the video/GIF
         pipeline, then mark the record ``"completed"`` or ``"failed"``.

    Errors are reported in the UI with ``st.error`` and are not re-raised.
    """
    # Matches the "(max 50MB)" promise in the uploader label below.
    max_upload_bytes = 50 * 1024 * 1024

    # NOTE(review): init_page() returns the sidebar confidence threshold, but
    # it is not passed to the detector here, so the slider has no effect —
    # TODO wire it into PoseDetector once its API is confirmed.
    init_page()
    components = init_components()
    uploaded_file = st.file_uploader(
        "Choose an image or video file (max 50MB)",
        type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif']
    )
    if not uploaded_file:
        st.warning("Please upload a file to begin.")
        return
    if uploaded_file.size > max_upload_bytes:
        st.error("File is too large. The maximum supported size is 50MB.")
        return

    # NOTE(review): if the script is rerun for the same upload (e.g. after a
    # sidebar interaction), another ProcessedFile row and its pose/animation
    # rows are inserted — consider deduplicating per upload.
    db = next(get_db())
    try:
        file_type = uploaded_file.type.split('/')[0]
        is_gif = uploaded_file.type == 'image/gif'
        processed_file = handle_upload(file_type, uploaded_file, components, db)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Original")
        with col2:
            st.subheader("Processed")

        try:
            if file_type == 'image' and not is_gif:
                process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
            else:
                process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
        except Exception as e:
            st.error(f"Processing error: {e}")
            # If the failure came from a database operation, the session must
            # be rolled back before it can commit again (assumes a
            # SQLAlchemy-style session). Otherwise the status update below
            # would fail as well.
            db.rollback()
            processed_file.processing_status = "failed"
            db.commit()
            return

        processed_file.processing_status = "completed"
        db.commit()
    except Exception as e:
        st.error(f"An error occurred: {e}")
    finally:
        db.close()


def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
    """Process a still image and show the original next to the processed result.

    The decoded image and the pose-processing results are cached in
    ``st.session_state`` and reused on reruns. The cache is only reused while
    the same file stays uploaded; files are identified by a SHA-256 digest of
    their bytes, and a different upload invalidates the cache. Interactive
    correction is disabled.

    Args:
        uploaded_file: The uploaded image; its bytes are read via ``getvalue()``.
        components: ``(pose_detector, skeleton_generator, animation_exporter)``.
        processed_file: DB record for this upload; its ``id`` links saved rows.
        db: Database session passed to the save helpers.
        col1: Column that displays the original image.
        col2: Column that displays the processed image and the download button.

    Raises:
        ValueError: If the uploaded bytes cannot be decoded as an image.
    """
    pose_detector, skeleton_generator, animation_exporter = components

    file_bytes = uploaded_file.getvalue()
    digest = hashlib.sha256(file_bytes).hexdigest()
    if st.session_state.get("uploaded_image_digest") != digest:
        # A file other than the cached one: decode it and drop any results
        # that were computed for the previous upload.
        image = cv2.imdecode(np.asarray(bytearray(file_bytes), dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError("Could not decode the uploaded image")
        st.session_state.uploaded_image = image
        for stale_key in ("original_skeleton_data", "processed_image"):
            st.session_state.pop(stale_key, None)
        st.session_state.uploaded_image_digest = digest
    image = st.session_state.uploaded_image

    with col1:
        # OpenCV decodes to BGR; Streamlit displays RGB.
        st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)

    if "original_skeleton_data" not in st.session_state:
        processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
        st.session_state.original_skeleton_data = skeleton_data
        st.session_state.processed_image = processed_image
    else:
        skeleton_data = st.session_state.original_skeleton_data
        processed_image = st.session_state.processed_image

    save_pose_data(db, processed_file.id, skeleton_data)
    animation_data_binary = animation_exporter.export_pose(skeleton_data)
    save_animation_data(db, processed_file.id, skeleton_data)

    with col2:
        st.image(cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB), use_column_width=True)
        # Offer the exported pose for download, matching the video path.
        provide_download_button(animation_data_binary)
        if st.session_state.get('manual_correction', False):
            st.info("Интерактивная корректировка временно отключена.")


def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """Process a video or GIF and show the original next to the processed result.

    The upload is written to a named temporary file so the processing helpers
    can open it by path. The file keeps the upload's extension, so helpers that
    infer the container format from the file name can do so. It is deleted
    when the ``with`` block exits.

    Args:
        uploaded_file: The uploaded video/GIF; its bytes are read via ``getvalue()``.
        components: ``(pose_detector, skeleton_generator, animation_exporter)``.
        processed_file: DB record for this upload; its ``id`` links saved rows.
        db: Database session passed to ``save_video_data``.
        is_gif: True for ``image/gif`` uploads, which use ``process_gif``.
        col1: Column that displays the original media.
        col2: Column that displays the processed video and the download button.

    Raises:
        ValueError: If no poses are detected in any frame.
    """
    pose_detector, skeleton_generator, animation_exporter = components
    progress_bar = st.progress(0)
    file_bytes = uploaded_file.getvalue()
    suffix = os.path.splitext(uploaded_file.name)[1]

    with tempfile.NamedTemporaryFile(suffix=suffix) as tfile:
        tfile.write(file_bytes)
        # Push the bytes to disk before other code reopens the file by name.
        tfile.flush()

        with col1:
            if is_gif:
                # A <video> element cannot play GIF data; show it as an image.
                st.image(file_bytes)
            else:
                st.video(file_bytes)

        video_path = tfile.name
        if is_gif:
            processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
        else:
            processed_video_path, animation_frames = process_video(video_path, pose_detector, skeleton_generator)
        if not animation_frames:
            raise ValueError("No poses detected in the video/gif")

        save_video_data(db, processed_file.id, animation_frames)
        animation_data_binary = animation_exporter.export_animation(animation_frames)
        progress_bar.progress(100)

        with col2:
            # NOTE(review): processed_video_path is never removed here; confirm
            # whether process_video/process_gif expect the caller to clean it up.
            if processed_video_path:
                with open(processed_video_path, "rb") as f:
                    st.video(f.read())
            provide_download_button(animation_data_binary)


def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Insert one ``PoseData`` row (no frame number) and commit it.

    Args:
        db: Session the row is added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        skeleton_data: Stored as the row's ``landmarks``.
    """
    row = PoseData(file_id=file_id, landmarks=skeleton_data)
    db.add(row)
    db.commit()


def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Insert one ``AnimationData`` row and commit it.

    Args:
        db: Session the row is added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        skeleton_data: Stored as the row's ``skeleton_data``.
    """
    record = AnimationData(file_id=file_id, skeleton_data=skeleton_data)
    db.add(record)
    db.commit()


def save_video_data(db, file_id: int, animation_frames: list):
    """Insert one ``PoseData`` row per frame and commit them in a single commit.

    Args:
        db: Session the rows are added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        animation_frames: Per-frame landmark data. Each frame's 0-based
            position becomes its ``frame_number``.
    """
    # Lazy generator, so each row is constructed right before it is added.
    rows = (
        PoseData(file_id=file_id, frame_number=index, landmarks=landmarks)
        for index, landmarks in enumerate(animation_frames)
    )
    for row in rows:
        db.add(row)
    db.commit()


def provide_download_button(animation_data_binary, file_name: str = "animation.uasset"):
    """Render a button that downloads the exported animation data.

    Args:
        animation_data_binary: Payload handed to ``st.download_button``.
        file_name: File name suggested to the browser.
            Defaults to ``"animation.uasset"``.

    NOTE(review): the sidebar "Export Format" choice is not consulted here, and
    callers in this file rely on the default name.
    """
    st.download_button(
        label="Download Animation Data",
        data=animation_data_binary,
        file_name=file_name,
        mime="application/octet-stream"
    )


def save_corrected_pose(db, file_id: int, joints: dict):
    """Insert and commit a ``PoseData`` row flagged as manually corrected.

    The same ``joints`` mapping is stored as both ``landmarks`` and
    ``corrected_landmarks``. Not called from the code shown here, because
    interactive correction is disabled.
    """
    corrected = PoseData(
        file_id=file_id,
        landmarks=joints,
        corrected_landmarks=joints,
        is_corrected=True,
    )
    db.add(corrected)
    db.commit()


def show_instructions():
    """Render a collapsible "Instructions" panel explaining how to use the app.

    The "Supported formats" list mirrors the extensions accepted by the file
    uploader in ``main()``.
    """
    # NOTE(review): step 3 describes the intended effect of the confidence
    # slider; the code shown here does not yet pass the slider value to the
    # pose detector.
    with st.expander("Instructions"):
        st.markdown("""
        1. Upload an image or video (max 50MB) using the file uploader.
        2. Wait for processing to complete.
        3. The **Detection Confidence** slider in the sidebar sets the minimum confidence threshold for detecting a pose.
           A higher value (e.g., 0.7) means that only detections with high certainty are considered, which may improve accuracy.
        4. The processed image/video will be displayed in the "Processed" panel.
        5. Download animation data.

        Supported formats:
        - Images: JPG, JPEG, PNG
        - Videos: MP4, AVI, GIF
        """)


if __name__ == "__main__":
    # Script entry point: render the app, then the instructions panel after it.
    main()
    show_instructions()