# Pose-Detection-App / utils.py
# Helpers for processing images, videos, and GIFs through the pose-detection
# pipeline, persisting skeleton data, and rendering results in Streamlit.
import os
import subprocess
import tempfile

import cv2
import numpy as np
import streamlit as st

from animation_renderer import AnimationRenderer
def process_image(image, pose_detector, skeleton_generator):
    """
    Run pose detection on a single image and build its skeleton data.

    Returns ``(annotated_image, skeleton_data)`` when a pose is found, and
    ``(image, None)`` when no pose is detected or an error occurs.
    """
    try:
        landmarks, annotated_image = pose_detector.detect(image)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return image, None

    # No pose found: hand back the untouched input.
    if landmarks is None:
        return image, None

    try:
        return annotated_image, skeleton_generator.generate_skeleton(landmarks)
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        return image, None
def process_video(video_path, pose_detector, skeleton_generator):
    """
    Process a video for pose detection and skeleton generation.

    Reads frames from ``video_path``, runs pose detection on each frame,
    writes an annotated MP4 (re-encoded to H.264 for browser playback), and
    collects per-frame skeleton data.

    Args:
        video_path: Path to the input video file.
        pose_detector: Object exposing ``detect_video_frame(frame)`` and a
            ``pose_connections`` attribute.
        skeleton_generator: Object exposing ``generate_skeleton(landmarks)``.

    Returns:
        Tuple ``(output_video_path, animation_frames)`` on success, or
        ``(None, None)`` on failure.
    """
    cap = None
    out = None
    try:
        # Limit OpenCV resource usage for constrained hosting environments.
        cv2.setNumThreads(2)
        cv2.ocl.setUseOpenCL(False)

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError("Could not open video file")

        # Actual capture properties; fps falls back to 30 when unreported.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

        renderer = AnimationRenderer(fps=fps)

        # One temporary output file. (The previous implementation created a
        # second NamedTemporaryFile later and leaked this first one.)
        temp_output = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        output_path = temp_output.name
        temp_output.close()

        # Downscale output so the longest side is at most 480px.
        max_dimension = 480
        if frame_width > max_dimension or frame_height > max_dimension:
            scale = min(max_dimension / frame_width, max_dimension / frame_height)
            frame_width = int(frame_width * scale)
            frame_height = int(frame_height * scale)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, min(fps, 30),
                              (frame_width, frame_height))

        animation_frames = []
        frame_count = 0
        frame_time = 0.0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Resize to the writer's dimensions to reduce memory usage and
            # keep the written frames consistent with the writer geometry.
            if frame.shape[1] > frame_width:
                frame = cv2.resize(frame, (frame_width, frame_height))

            try:
                # Best-effort: retry detection a few times on a flaky frame
                # (the detector may be stateful and succeed on a retry).
                retries = 3
                landmarks, annotated_frame = None, None
                while retries > 0:
                    landmarks, annotated_frame = pose_detector.detect_video_frame(frame)
                    if landmarks is not None:
                        break
                    retries -= 1

                # Write the annotated frame when available, raw otherwise.
                out.write(annotated_frame if annotated_frame is not None else frame)

                if landmarks is not None:
                    try:
                        skeleton_data = skeleton_generator.generate_skeleton(landmarks)
                        animation_frames.append(skeleton_data)
                        renderer.add_keyframe(landmarks,
                                              pose_detector.pose_connections,
                                              frame_time)
                    except Exception as e:
                        print(f"Frame {frame_count} skeleton generation error: {str(e)}")
                        # On error, reuse the last valid skeleton frame.
                        if animation_frames:
                            animation_frames.append(animation_frames[-1])
                elif animation_frames:
                    # No new landmarks: duplicate the previous skeleton frame.
                    animation_frames.append(animation_frames[-1])
            except Exception as e:
                print(f"Frame {frame_count} processing error: {str(e)}")

            # Count every consumed frame so frame_time keeps advancing even
            # after a per-frame error (the old `continue` skipped this).
            frame_count += 1
            frame_time = frame_count / fps
            if frame_count > 1000:  # Safety limit against runaway inputs
                break

        cap.release()
        cap = None
        out.release()
        out = None

        # Re-encode with x264 for browser playback. subprocess.run with an
        # argument list replaces the previous shell-interpolated os.system()
        # call, which broke on paths with spaces and was injection-prone.
        converted = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        converted.close()
        result = subprocess.run(
            ['ffmpeg', '-y', '-i', output_path,
             '-vcodec', 'libx264', '-preset', 'ultrafast',
             '-pix_fmt', 'yuv420p', converted.name],
            capture_output=True,
        )
        if result.returncode != 0:
            # Conversion failed: fall back to the unconverted mp4v output.
            print(f"ffmpeg conversion failed: {result.stderr.decode(errors='ignore')}")
            os.unlink(converted.name)
            return output_path, animation_frames

        os.unlink(output_path)
        return converted.name, animation_frames
    except Exception as e:
        print(f"Error processing video: {str(e)}")
        return None, None
    finally:
        # Always release OpenCV handles, including on the error path.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def process_gif(gif_path, pose_detector, skeleton_generator):
    """
    Process a GIF for pose detection and skeleton generation.

    Extracts frames with Pillow, runs pose detection on each frame, and
    writes the processed frames to a temporary MP4 at a fixed 10 fps.

    Args:
        gif_path: Path to the input GIF file.
        pose_detector: Object exposing ``detect_video_frame(frame)``.
        skeleton_generator: Object exposing ``generate_skeleton(landmarks)``.

    Returns:
        Tuple ``(mp4_path, animation_frames)`` on success, or
        ``(None, None)`` on failure (including a GIF with no frames).
    """
    try:
        from PIL import Image, ImageSequence

        # Extract all frames as BGR arrays for OpenCV; the context manager
        # closes the GIF file handle (previously leaked).
        with Image.open(gif_path) as gif:
            frames = [
                cv2.cvtColor(np.array(f.convert("RGB")), cv2.COLOR_RGB2BGR)
                for f in ImageSequence.Iterator(gif)
            ]
        if not frames:
            # Previously this surfaced as an IndexError swallowed below.
            raise ValueError("GIF contains no frames")

        processed_frames = []
        animation_frames = []
        for frame in frames:
            landmarks, annotated_frame = pose_detector.detect_video_frame(frame)
            processed_frames.append(
                annotated_frame if annotated_frame is not None else frame
            )
            if landmarks is not None:
                skeleton_data = skeleton_generator.generate_skeleton(landmarks)
            else:
                # No new landmarks: reuse the previous skeleton (or empty).
                skeleton_data = animation_frames[-1] if animation_frames else {}
            animation_frames.append(skeleton_data)

        # Assemble the processed frames into an MP4.
        height, width = processed_frames[0].shape[:2]
        fps = 10
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        temp_video = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
        out = cv2.VideoWriter(temp_video.name, fourcc, fps, (width, height))
        try:
            for frame in processed_frames:
                out.write(frame)
        finally:
            # Release the writer even if a write fails.
            out.release()
        return temp_video.name, animation_frames
    except Exception as e:
        print(f"Error processing GIF: {str(e)}")
        return None, None
def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """
    Handle a video/GIF file upload.

    Shows the original file in the left column and the processed MP4 in the
    right column, persists per-frame skeleton data, and offers the exported
    animation for download.

    Args:
        uploaded_file: Streamlit UploadedFile holding the raw bytes.
        components: Tuple ``(pose_detector, skeleton_generator, animation_exporter)``.
        processed_file: DB record for the upload; its ``id`` keys the pose data.
        db: Database session passed through to ``save_video_data``.
        is_gif: True when the upload is a GIF rather than a regular video.
        col1: Streamlit column for the original media.
        col2: Streamlit column for the processed result.

    Raises:
        ValueError: If no poses were detected in the video/GIF.
    """
    pose_detector, skeleton_generator, animation_exporter = components
    file_bytes = uploaded_file.read()

    # Show the original: st.image for GIFs, st.video for regular videos.
    with col1:
        if is_gif:
            st.image(file_bytes, use_column_width=True)
        else:
            st.video(file_bytes)

    # Persist the upload to a temp file for processing. Close it right after
    # writing so the processors can reopen it (required on Windows, and the
    # handle was previously leaked), and delete it when done.
    temp_input = tempfile.NamedTemporaryFile(
        suffix=('.gif' if is_gif else '.mp4'), delete=False
    )
    try:
        temp_input.write(file_bytes)
        temp_input.close()
        video_path = temp_input.name

        if is_gif:
            processed_video_path, animation_frames = process_gif(
                video_path, pose_detector, skeleton_generator
            )
        else:
            processed_video_path, animation_frames = process_video(
                video_path, pose_detector, skeleton_generator
            )

        if not animation_frames:
            raise ValueError("No poses detected in the video/gif")

        # Persist skeleton data, then export the animation payload.
        save_video_data(db, processed_file.id, animation_frames)
        animation_data_binary = animation_exporter.export_animation(animation_frames)

        # Show the processed MP4 in the right column.
        with col2:
            if processed_video_path:
                with open(processed_video_path, "rb") as f:
                    st.video(f.read())
            provide_download_button(animation_data_binary)
    finally:
        temp_input.close()
        try:
            os.unlink(temp_input.name)
        except OSError:
            pass  # best-effort cleanup of the temp input
def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Persist one skeleton as a PoseData row and commit the session."""
    from database import PoseData

    record = PoseData(file_id=file_id, landmarks=skeleton_data)
    db.add(record)
    db.commit()
def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Persist one skeleton as an AnimationData row and commit the session."""
    from database import AnimationData

    record = AnimationData(file_id=file_id, skeleton_data=skeleton_data)
    db.add(record)
    db.commit()
def save_video_data(db, file_id: int, animation_frames: list):
    """Persist one PoseData row per animation frame, then commit once."""
    from database import PoseData

    for index, landmarks in enumerate(animation_frames):
        # Frames with no detected landmarks are stored as empty dicts.
        db.add(PoseData(
            file_id=file_id,
            frame_number=index,
            landmarks=landmarks or {},
        ))
    db.commit()
def provide_download_button(animation_data_binary):
    """Render a Streamlit button that downloads the exported animation."""
    button_kwargs = {
        "label": "Download Animation Data",
        "data": animation_data_binary,
        "file_name": "animation.uasset",
        "mime": "application/octet-stream",
    }
    st.download_button(**button_kwargs)