vertalius's picture
Update app.py
b7c53bf verified
import os
import tempfile
from datetime import datetime
from typing import Optional, Tuple

import streamlit as st
import cv2
import numpy as np
from PIL import Image

from pose_detector import PoseDetector
from skeleton_generator import SkeletonGenerator
from animation_exporter import AnimationExporter
from utils import process_video, process_image, process_gif
from database import get_db, ProcessedFile, PoseData, AnimationData
# Custom CSS injected once per run by init_page().
_PAGE_CSS = """
<style>
/* Base styling */
.stApp {
max-width: 100% !important;
padding: 2rem;
transition: background-color 0.3s;
}
[data-theme="dark"] .stApp {
background-color: #1E1E1E;
color: #FFFFFF;
}
.stButton>button {
background-color: #4CAF50;
color: white;
border-radius: 4px;
padding: 0.5rem 1rem;
}
.stDownloadButton>button {
background-color: #008CBA;
}
.stProgress > div > div > div {
background-color: #4CAF50;
}
.stColumn {
padding: 0 1rem;
}
</style>
"""


def init_page():
    """Configure the page, inject custom CSS and render the sidebar settings.

    ``st.set_page_config`` is called first, before any other Streamlit command
    in this function.

    Sidebar widgets:
      * "Detection Confidence" slider (0.0-1.0, step 0.1) - its value is returned.
      * "Export Format" select box - readable as ``st.session_state.export_format``.
      * "Enable Manual Corrections" checkbox - readable as
        ``st.session_state.manual_correction``.

    Returns:
        float: the confidence threshold currently selected on the slider.
    """
    st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)

    st.sidebar.title("Settings")
    confidence_threshold = st.sidebar.slider(
        "Detection Confidence",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )
    # The selection is consumed through st.session_state.export_format (key below),
    # so no local variable is kept.
    st.sidebar.selectbox(
        "Export Format",
        options=['uasset', 'fbx', 'bvh'],
        key='export_format',
    )
    # Binding the checkbox to a session-state key lets Streamlit persist its value
    # across reruns. The previous approach re-passed `value=` from session state on
    # every run, which changes the widget's default (and hence its identity) whenever
    # the user toggles it.
    manual_correction = st.sidebar.checkbox("Enable Manual Corrections", key="manual_correction")
    if manual_correction:
        # Message: "Interactive correction is temporarily disabled."
        st.sidebar.info("Интерактивная корректировка временно отключена.")

    st.title("Pose Detection & Animation Generator")
    return confidence_threshold
def init_components() -> Tuple[PoseDetector, SkeletonGenerator, AnimationExporter]:
    """Build new instances of the three processing components.

    Returns:
        A ``(pose_detector, skeleton_generator, animation_exporter)`` tuple, in
        that order, each created with its no-argument constructor.
    """
    detector = PoseDetector()
    generator = SkeletonGenerator()
    exporter = AnimationExporter()
    return detector, generator, exporter
def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
    """Insert a ``ProcessedFile`` row for the upload with status "processing".

    Args:
        file_type: Top-level MIME category computed by the caller (e.g. "image").
        uploaded_file: The Streamlit upload; its ``name`` and ``type`` are read.
        components: Accepted for interface compatibility; not used here.
        db_session: Session used to add, commit and refresh the new row.

    Returns:
        The committed and refreshed ``ProcessedFile`` instance.
    """
    # GIFs arrive as "image/gif" but are processed frame by frame, so they are
    # recorded as videos.
    stored_type = file_type
    if uploaded_file.type == 'image/gif':
        stored_type = 'video'

    record = ProcessedFile(
        filename=uploaded_file.name,
        file_type=stored_type,
        processing_status="processing",
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record
def main():
    """Render the app, accept one upload and run it through the processing pipeline.

    Flow:
      1. ``init_page()`` renders the settings; ``init_components()`` builds the
         processors.
      2. Without an upload (or with one above 50 MB) a message is shown and the
         function returns.
      3. A ``ProcessedFile`` row is created with status "processing".
      4. Non-GIF images go to ``process_image_upload``; videos and GIFs go to
         ``process_video_upload``.
      5. The row is marked "completed", or "failed" if processing raised.

    Errors are reported in the UI via ``st.error`` rather than propagated, and the
    database session is always closed.
    """
    # NOTE(review): init_page() returns the slider's confidence threshold, but it is
    # discarded here and never reaches PoseDetector, so the slider has no effect yet.
    init_page()
    components = init_components()

    max_upload_bytes = 50 * 1024 * 1024
    uploaded_file = st.file_uploader(
        "Choose an image or video file (max 50MB)",
        type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif']
    )
    if not uploaded_file:
        st.warning("Please upload a file to begin.")
        return
    # Streamlit's own limit is configured separately (server.maxUploadSize), so the
    # 50MB promised in the label is enforced here.
    if uploaded_file.size > max_upload_bytes:
        st.error(
            f"File is too large ({uploaded_file.size / (1024 * 1024):.1f} MB); "
            "the maximum is 50 MB."
        )
        return

    # NOTE(review): get_db() appears to be a generator-style session provider; only
    # its first session is used, and it is closed explicitly in `finally` below.
    # Every Streamlit rerun executes this function again and inserts a new row.
    db = next(get_db())
    try:
        file_type = uploaded_file.type.split('/')[0]
        is_gif = uploaded_file.type == 'image/gif'
        processed_file = handle_upload(file_type, uploaded_file, components, db)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Original")
        with col2:
            st.subheader("Processed")

        try:
            if file_type == 'image' and not is_gif:
                process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
            else:
                process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
        except Exception as e:
            st.error(f"Processing error: {str(e)}")
            # If the failure came from a commit inside a helper, the session must be
            # rolled back before it can be used again; otherwise the status update
            # below would fail as well.
            db.rollback()
            processed_file.processing_status = "failed"
            db.commit()
            return

        processed_file.processing_status = "completed"
        db.commit()
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        # Discard any half-finished transaction before the session is closed.
        db.rollback()
    finally:
        db.close()
def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
    """Process a still image and show the original next to the annotated result.

    The decoded image, the skeleton data and the rendered result are cached in
    ``st.session_state`` so Streamlit reruns do not repeat pose detection. The cache
    is tied to the identity of the current upload, so choosing a different file
    invalidates it. Interactive correction is currently disabled.

    Args:
        uploaded_file: The Streamlit upload holding the encoded image bytes.
        components: ``(pose_detector, skeleton_generator, animation_exporter)``.
        processed_file: The ``ProcessedFile`` row the results belong to.
        db: Database session passed to the save helpers.
        col1: Column for the original image.
        col2: Column for the processed image and the download button.

    Raises:
        ValueError: if the upload cannot be decoded as an image.
    """
    pose_detector, skeleton_generator, animation_exporter = components

    # Without this check, the first image cached in a session was reused for every
    # later upload, so a new file was neither displayed nor saved.
    upload_key = (getattr(uploaded_file, "file_id", None), uploaded_file.name, uploaded_file.size)
    if st.session_state.get("uploaded_image_key") != upload_key:
        for stale_key in ("uploaded_image", "original_skeleton_data", "processed_image"):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
        st.session_state.uploaded_image_key = upload_key

    if "uploaded_image" not in st.session_state:
        # getvalue() returns the full buffer regardless of the current read position.
        file_bytes = np.asarray(bytearray(uploaded_file.getvalue()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imdecode signals undecodable data by returning None.
            raise ValueError("Could not decode the uploaded file as an image")
        st.session_state.uploaded_image = image
    else:
        image = st.session_state.uploaded_image

    with col1:
        # OpenCV decodes to BGR; Streamlit expects RGB.
        st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)

    if "original_skeleton_data" not in st.session_state:
        processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
        st.session_state.original_skeleton_data = skeleton_data
        st.session_state.processed_image = processed_image
    else:
        skeleton_data = st.session_state.original_skeleton_data
        processed_image = st.session_state.processed_image

    save_pose_data(db, processed_file.id, skeleton_data)
    animation_data_binary = animation_exporter.export_pose(skeleton_data)
    save_animation_data(db, processed_file.id, skeleton_data)

    with col2:
        processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
        st.image(processed_rgb, use_column_width=True)
        # The exported pose used to be computed and then discarded; offer it for
        # download as the video path does.
        provide_download_button(animation_data_binary)

    if st.session_state.get('manual_correction', False):
        # Message: "Interactive correction is temporarily disabled."
        st.info("Интерактивная корректировка временно отключена.")
def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """Process a video or GIF upload and show the input next to the processed result.

    The upload is written to a temporary file (keeping its original extension)
    because ``process_video``/``process_gif`` take a path. The temporary file is
    removed once processing and display are finished.

    Args:
        uploaded_file: The Streamlit upload holding the encoded video/GIF bytes.
        components: ``(pose_detector, skeleton_generator, animation_exporter)``.
        processed_file: The ``ProcessedFile`` row the results belong to.
        db: Database session passed to ``save_video_data``.
        is_gif: True for GIF uploads, which use ``process_gif``.
        col1: Column for the original media.
        col2: Column for the processed video and the download button.

    Raises:
        ValueError: if no poses were detected.
    """
    pose_detector, skeleton_generator, animation_exporter = components
    progress_bar = st.progress(0)
    file_bytes = uploaded_file.getvalue()

    with col1:
        if is_gif:
            # st.video cannot play GIF data; st.image renders GIFs.
            st.image(file_bytes, use_column_width=True)
        else:
            st.video(file_bytes)

    # Keep the extension so decoders that rely on the file suffix pick the right format.
    suffix = os.path.splitext(uploaded_file.name)[1]
    # delete=False: the helpers reopen the file by name, which is not possible
    # on Windows while a NamedTemporaryFile is still open. Cleanup is done below.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
        tfile.write(file_bytes)
        video_path = tfile.name

    try:
        if is_gif:
            processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
        else:
            processed_video_path, animation_frames = process_video(video_path, pose_detector, skeleton_generator)
        progress_bar.progress(100)

        if not animation_frames:
            raise ValueError("No poses detected in the video/gif")

        save_video_data(db, processed_file.id, animation_frames)
        animation_data_binary = animation_exporter.export_animation(animation_frames)

        with col2:
            if processed_video_path:
                with open(processed_video_path, "rb") as f:
                    st.video(f.read())
            provide_download_button(animation_data_binary)
    finally:
        # Best-effort removal of the input copy; a leftover temp file is not fatal.
        try:
            os.unlink(video_path)
        except OSError:
            pass
def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Store ``skeleton_data`` as the landmarks of a new ``PoseData`` row and commit.

    Args:
        db: Session the row is added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        skeleton_data: Landmark payload stored as-is.
    """
    db.add(PoseData(file_id=file_id, landmarks=skeleton_data))
    db.commit()
def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Insert an ``AnimationData`` row holding ``skeleton_data`` and commit at once.

    Args:
        db: Session the row is added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        skeleton_data: Skeleton payload stored as-is.
    """
    record = AnimationData(skeleton_data=skeleton_data, file_id=file_id)
    db.add(record)
    db.commit()
def save_video_data(db, file_id: int, animation_frames: list):
    """Insert one ``PoseData`` row per frame, numbered from 0, then commit once.

    Args:
        db: Session the rows are added to.
        file_id: Id of the owning ``ProcessedFile``.
        animation_frames: Per-frame landmark payloads, in frame order.
    """
    rows = (
        PoseData(file_id=file_id, frame_number=index, landmarks=landmarks)
        for index, landmarks in enumerate(animation_frames)
    )
    for row in rows:
        db.add(row)
    db.commit()
def provide_download_button(animation_data_binary, *, file_name: str = "animation.uasset",
                            mime: str = "application/octet-stream"):
    """Render a "Download Animation Data" button for exported animation data.

    Args:
        animation_data_binary: Payload passed to ``st.download_button`` as ``data``.
        file_name: Suggested name for the downloaded file. Defaults to
            "animation.uasset", the previously hard-coded value.
        mime: MIME type sent with the download.
    """
    # NOTE(review): the sidebar "Export Format" choice (st.session_state.export_format)
    # is not applied here or passed to the exporter. Callers can pass a matching
    # file_name once the exporter produces other formats.
    st.download_button(
        label="Download Animation Data",
        data=animation_data_binary,
        file_name=file_name,
        mime=mime,
    )
def save_corrected_pose(db, file_id: int, joints: dict):
    """Insert a ``PoseData`` row flagged as corrected and commit.

    The same ``joints`` payload is written to both ``landmarks`` and
    ``corrected_landmarks``.

    Args:
        db: Session the row is added to and committed on.
        file_id: Id of the owning ``ProcessedFile``.
        joints: Corrected joint payload.
    """
    record = PoseData(
        file_id=file_id,
        is_corrected=True,
        landmarks=joints,
        corrected_landmarks=joints,
    )
    db.add(record)
    db.commit()
def show_instructions():
    """Render a collapsible "Instructions" panel with usage steps and accepted formats.

    The format list mirrors the ``type=`` list passed to ``st.file_uploader`` in
    ``main`` (jpg, jpeg, png, mp4, avi, gif).
    """
    # NOTE(review): step 3 describes the confidence slider, but main() discards the
    # value returned by init_page(), so the threshold is not applied yet.
    with st.expander("Instructions"):
        st.markdown("""
1. Upload an image or video using the file uploader.
2. Wait for processing to complete.
3. The **Detection Confidence** slider in the sidebar sets the minimum confidence threshold for detecting a pose.
A higher value (e.g., 0.7) means that only detections with high certainty are considered, which may improve accuracy.
4. The processed image/video will be displayed in the "Processed" panel.
5. Download animation data.
Supported formats:
- Images: JPG, JPEG, PNG
- Videos: MP4, AVI, GIF
""")
if __name__ == "__main__":
    # Run the main workflow first, then render the instructions panel so it
    # appears beneath the workflow's output.
    for _render_step in (main, show_instructions):
        _render_step()