# Utility helpers for pose detection, skeleton generation, and video/GIF processing.
| import cv2 | |
| import tempfile | |
| import numpy as np | |
| import os | |
| import streamlit as st | |
| from animation_renderer import AnimationRenderer | |
def process_image(image, pose_detector, skeleton_generator):
    """
    Run pose detection on a single image and build its skeleton data.

    Returns (annotated_image, skeleton_data) when a pose is found, or
    (original image, None) when detection finds nothing or any step raises.
    """
    try:
        landmarks, annotated = pose_detector.detect(image)
        if landmarks is None:
            return image, None
        return annotated, skeleton_generator.generate_skeleton(landmarks)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return image, None
def process_video(video_path, pose_detector, skeleton_generator):
    """
    Process a video frame by frame: detect poses, write an annotated MP4,
    and collect per-frame skeleton data.

    Parameters
    ----------
    video_path : str
        Path to the input video file.
    pose_detector : object
        Provides detect_video_frame(frame) -> (landmarks, annotated_frame)
        and a ``pose_connections`` attribute.
    skeleton_generator : object
        Provides generate_skeleton(landmarks) -> dict.

    Returns
    -------
    tuple
        (path to an H.264-encoded temporary MP4, list of skeleton dicts),
        or (None, None) on failure.
    """
    import subprocess  # used instead of os.system for the ffmpeg re-encode

    cap = None
    out = None
    output_path = None
    try:
        # Keep OpenCV resource usage modest for constrained environments.
        cv2.setNumThreads(2)
        cv2.ocl.setUseOpenCL(False)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        # Read properties once; some containers report fps == 0.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

        # Cap the larger dimension at 480 px for processing speed.
        max_dimension = 480
        if frame_width > max_dimension or frame_height > max_dimension:
            scale = min(max_dimension / frame_width, max_dimension / frame_height)
            frame_width = int(frame_width * scale)
            frame_height = int(frame_height * scale)

        renderer = AnimationRenderer(fps=fps)

        # Single temporary output file (the original created two and leaked one).
        try:
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
                output_path = tmp.name
        except Exception as e:
            raise RuntimeError(f"Failed to create temporary file: {str(e)}")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, min(fps, 30),
                              (frame_width, frame_height))

        animation_frames = []
        frame_count = 0
        frame_time = 0.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Resize to the writer's dimensions so every written frame matches.
            if frame.shape[1] != frame_width or frame.shape[0] != frame_height:
                frame = cv2.resize(frame, (frame_width, frame_height))
            try:
                # Retry detection a few times; presumably the detector is
                # stateful and may need a warm-up call — TODO confirm.
                landmarks, annotated_frame = None, None
                for _ in range(3):
                    landmarks, annotated_frame = pose_detector.detect_video_frame(frame)
                    if landmarks is not None:
                        break
                # Fall back to the raw frame when annotation failed.
                out.write(annotated_frame if annotated_frame is not None else frame)
                if landmarks is not None:
                    try:
                        skeleton_data = skeleton_generator.generate_skeleton(landmarks)
                        animation_frames.append(skeleton_data)
                        renderer.add_keyframe(landmarks, pose_detector.pose_connections, frame_time)
                    except Exception as e:
                        print(f"Frame {frame_count} skeleton generation error: {str(e)}")
                        # On error, repeat the last good skeleton to keep frames aligned.
                        if animation_frames:
                            animation_frames.append(animation_frames[-1])
                elif animation_frames:
                    # No landmarks this frame: duplicate the previous skeleton.
                    animation_frames.append(animation_frames[-1])
            except Exception as e:
                print(f"Frame {frame_count} processing error: {str(e)}")
            # Count every read frame (the original skipped this on errors,
            # which froze timestamps and defeated the safety limit).
            frame_count += 1
            frame_time = frame_count / fps
            if frame_count > 1000:  # safety limit
                break

        cap.release()
        cap = None
        out.release()
        out = None

        # Re-encode with x264 so browsers (Streamlit st.video) can play it.
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as conv:
            converted_path = conv.name
        # Argument list (not a shell string): safe for paths with spaces.
        subprocess.run(
            ['ffmpeg', '-y', '-i', output_path,
             '-vcodec', 'libx264', '-preset', 'ultrafast',
             '-pix_fmt', 'yuv420p', converted_path],
            check=False,
        )
        os.unlink(output_path)
        return converted_path, animation_frames
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # Don't leave a partial output file behind.
        if output_path is not None and os.path.exists(output_path):
            os.unlink(output_path)
        return None, None
def process_gif(gif_path, pose_detector, skeleton_generator):
    """
    Process an animated GIF: run pose detection on every frame and write
    the annotated frames to a temporary MP4.

    Returns (mp4_path, animation_frames), or (None, None) on failure —
    including a GIF with no decodable frames.
    """
    out = None
    try:
        from PIL import Image, ImageSequence

        # Decode every GIF frame into an OpenCV BGR image.
        with Image.open(gif_path) as gif:
            frames = [
                cv2.cvtColor(np.array(f.convert("RGB")), cv2.COLOR_RGB2BGR)
                for f in ImageSequence.Iterator(gif)
            ]
        if not frames:
            # The original crashed with IndexError here; fail explicitly instead.
            raise ValueError("GIF contains no frames")

        processed_frames = []
        animation_frames = []
        for frame in frames:
            landmarks, annotated_frame = pose_detector.detect_video_frame(frame)
            # Fall back to the raw frame when annotation failed.
            processed_frames.append(annotated_frame if annotated_frame is not None else frame)
            if landmarks is not None:
                skeleton_data = skeleton_generator.generate_skeleton(landmarks)
            else:
                # No landmarks: reuse the previous skeleton (empty dict if none yet).
                skeleton_data = animation_frames[-1] if animation_frames else {}
            animation_frames.append(skeleton_data)

        # Assemble the processed frames into an MP4 at a fixed 10 fps.
        height, width = processed_frames[0].shape[:2]
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
            video_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(video_path, fourcc, 10, (width, height))
        for frame in processed_frames:
            out.write(frame)
        out.release()
        out = None
        return video_path, animation_frames
    except Exception as e:
        print(f"Error processing GIF: {str(e)}")
        if out is not None:
            # Release the writer if a write raised mid-way (original leaked it).
            out.release()
        return None, None
def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """
    Handle an uploaded video/GIF file.

    Shows the original file in ``col1``, runs pose processing, persists the
    per-frame pose data, and shows the processed MP4 plus a download button
    in ``col2``.

    Raises
    ------
    ValueError
        When no poses are detected in the file.
    """
    pose_detector, skeleton_generator, animation_exporter = components
    file_bytes = uploaded_file.read()
    # Original preview: GIFs render via st.image, regular video via st.video.
    with col1:
        if is_gif:
            st.image(file_bytes, use_column_width=True)
        else:
            st.video(file_bytes)
    # Persist to a temp file and CLOSE the handle so the processors can
    # reopen it by path (required on Windows; the original left it open).
    suffix = '.gif' if is_gif else '.mp4'
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_input:
        temp_input.write(file_bytes)
        video_path = temp_input.name
    try:
        if is_gif:
            processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
        else:
            processed_video_path, animation_frames = process_video(video_path, pose_detector, skeleton_generator)
        if not animation_frames:
            raise ValueError("No poses detected in the video/gif")
        # Persist per-frame pose data.
        save_video_data(db, processed_file.id, animation_frames)
        animation_data_binary = animation_exporter.export_animation(animation_frames)
        # Show the processed MP4 in the right column.
        with col2:
            if processed_video_path:
                with open(processed_video_path, "rb") as f:
                    st.video(f.read())
            provide_download_button(animation_data_binary)
    finally:
        # Remove the temporary input copy (the original leaked it).
        if os.path.exists(video_path):
            os.unlink(video_path)
def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Persist a single-image skeleton as one PoseData row and commit."""
    from database import PoseData

    record = PoseData(file_id=file_id, landmarks=skeleton_data)
    db.add(record)
    db.commit()
def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Persist exported animation skeleton data as one AnimationData row and commit."""
    from database import AnimationData

    record = AnimationData(file_id=file_id, skeleton_data=skeleton_data)
    db.add(record)
    db.commit()
def save_video_data(db, file_id: int, animation_frames: list):
    """Persist per-frame skeletons as numbered PoseData rows, then commit once."""
    from database import PoseData

    for frame_number, landmarks in enumerate(animation_frames):
        # Frames without detected landmarks are stored as empty dicts.
        record = PoseData(
            file_id=file_id,
            frame_number=frame_number,
            landmarks=landmarks or {},
        )
        db.add(record)
    db.commit()
def provide_download_button(animation_data_binary):
    """Render a Streamlit button that downloads the exported animation bytes."""
    button_kwargs = {
        "label": "Download Animation Data",
        "data": animation_data_binary,
        "file_name": "animation.uasset",
        "mime": "application/octet-stream",
    }
    st.download_button(**button_kwargs)