File size: 8,646 Bytes
1e91e1d
 
 
 
 
 
f3457a8
0684bfb
1e91e1d
 
 
 
 
 
 
 
 
31c6ab0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e91e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
159d46f
 
 
 
 
 
b7c53bf
1e91e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31c6ab0
 
 
 
 
 
1e91e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
b7c53bf
1e91e1d
 
 
 
31c6ab0
1e91e1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7c53bf
1e91e1d
 
f3457a8
 
 
 
 
 
1e91e1d
 
 
f3457a8
 
 
 
 
 
 
 
1e91e1d
 
 
 
 
 
 
b7c53bf
 
 
1e91e1d
 
0684bfb
1e91e1d
 
 
31c6ab0
1e91e1d
31c6ab0
 
1e91e1d
31c6ab0
1e91e1d
31c6ab0
1e91e1d
 
 
 
 
 
 
 
 
 
 
 
 
31c6ab0
 
1e91e1d
 
 
 
 
 
 
 
 
1035fce
1e91e1d
 
 
 
 
1035fce
1e91e1d
 
 
 
 
 
 
 
 
 
 
1035fce
 
 
 
 
1e91e1d
 
 
b7c53bf
f3457a8
b7c53bf
 
 
 
f3457a8
31c6ab0
 
 
1e91e1d
 
 
 
31c6ab0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import hashlib
import os
import tempfile
from datetime import datetime
from typing import Optional, Tuple

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from pose_detector import PoseDetector
from skeleton_generator import SkeletonGenerator
from animation_exporter import AnimationExporter
from utils import process_video, process_image, process_gif
from database import get_db, ProcessedFile, PoseData, AnimationData

def init_page():
    """Configure the page, inject the custom CSS and build the sidebar controls.

    Side effects: sets the Streamlit page config, renders the sidebar
    widgets (confidence slider, export-format selectbox stored under the
    ``export_format`` session key, manual-correction checkbox) and the page
    title, and mirrors the checkbox state into
    ``st.session_state.manual_correction``.

    Returns:
        The value of the "Detection Confidence" slider (0.0-1.0).
    """
    st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")

    custom_css = """
    <style>
    /* Base styling */
    .stApp {
        max-width: 100% !important;
        padding: 2rem;
        transition: background-color 0.3s;
    }
    [data-theme="dark"] .stApp {
        background-color: #1E1E1E;
        color: #FFFFFF;
    }
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        border-radius: 4px;
        padding: 0.5rem 1rem;
    }
    .stDownloadButton>button {
        background-color: #008CBA;
    }
    .stProgress > div > div > div {
        background-color: #4CAF50;
    }
    .stColumn {
        padding: 0 1rem;
    }
    </style>
    """
    st.markdown(custom_css, unsafe_allow_html=True)

    sidebar = st.sidebar
    sidebar.title("Settings")
    threshold = sidebar.slider(
        "Detection Confidence",
        min_value=0.0,
        max_value=1.0,
        value=0.5,
        step=0.1,
    )

    # The selection lives in session state under 'export_format'; this
    # function does not read it back.
    sidebar.selectbox(
        "Export Format",
        options=['uasset', 'fbx', 'bvh'],
        key='export_format',
    )

    # Seed the checkbox with its previous state (False on the first run).
    previous_choice = st.session_state.get("manual_correction", False)
    st.session_state.manual_correction = sidebar.checkbox(
        "Enable Manual Corrections", value=previous_choice
    )

    if st.session_state.manual_correction:
        # Message text: "Interactive correction is temporarily disabled."
        sidebar.info("Интерактивная корректировка временно отключена.")

    st.title("Pose Detection & Animation Generator")
    return threshold

def init_components() -> Tuple[PoseDetector, SkeletonGenerator, AnimationExporter]:
    """Build the processing components with their default constructors.

    Returns:
        A ``(pose_detector, skeleton_generator, animation_exporter)`` tuple,
        in the order the processing functions unpack it.
    """
    detector = PoseDetector()
    generator = SkeletonGenerator()
    exporter = AnimationExporter()
    return detector, generator, exporter

def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
    """Create and persist a ``ProcessedFile`` record for a new upload.

    Args:
        file_type: Top-level MIME category of the upload (e.g. ``"image"``).
        uploaded_file: The Streamlit upload; its ``name`` and ``type`` are read.
        components: Unused; accepted for call-site compatibility.
        db_session: Session used to add, commit and refresh the record.

    Returns:
        The committed record, refreshed so that its ``id`` is populated.
    """
    # GIFs go through the video pipeline, so they are recorded as 'video'.
    stored_type = file_type
    if uploaded_file.type == 'image/gif':
        stored_type = 'video'

    record = ProcessedFile(
        filename=uploaded_file.name,
        file_type=stored_type,
        processing_status="processing",
    )
    db_session.add(record)
    db_session.commit()
    db_session.refresh(record)
    return record

def main():
    """Render the app, accept one upload and run the matching pipeline.

    A ``ProcessedFile`` row is created for the upload. Its
    ``processing_status`` ends up ``"completed"`` or ``"failed"`` depending
    on whether the image/video processing raised. Errors are reported with
    ``st.error`` rather than propagated.
    """
    # NOTE(review): init_page() returns the sidebar confidence threshold, but
    # the value is not passed to PoseDetector, so the slider currently has no
    # effect on detection.
    init_page()
    components = init_components()

    max_upload_bytes = 50 * 1024 * 1024  # limit advertised in the uploader label

    uploaded_file = st.file_uploader(
        "Choose an image or video file (max 50MB)", 
        type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif']
    )

    if not uploaded_file:
        st.warning("Please upload a file to begin.")
        return

    # Enforce the size limit the label promises before touching the database.
    if len(uploaded_file.getvalue()) > max_upload_bytes:
        st.error("File is too large: the maximum upload size is 50MB.")
        return

    db = next(get_db())
    try:
        file_type = uploaded_file.type.split('/')[0]
        is_gif = uploaded_file.type == 'image/gif'

        processed_file = handle_upload(file_type, uploaded_file, components, db)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Original")
        with col2:
            st.subheader("Processed")

        try:
            if file_type == 'image' and not is_gif:
                process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
            else:
                process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
        except Exception as e:
            st.error(f"Processing error: {str(e)}")
            # If the failure happened during a flush/commit, the session
            # refuses further work until rolled back; roll back first so the
            # "failed" status can actually be persisted.
            db.rollback()
            processed_file.processing_status = "failed"
            db.commit()
            return

        processed_file.processing_status = "completed"
        db.commit()

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
    finally:
        db.close()

def process_image_upload(uploaded_file, components, processed_file, db, col1, col2):
    """Process a still image, caching the decoded image and results in session state.

    Both the decoded image and the detection results are cached in
    ``st.session_state`` so that Streamlit reruns do not repeat detection.
    The cache is keyed on the upload's name and a SHA-256 of its bytes, so a
    different upload invalidates it. Interactive correction is disabled.

    Pose and animation data are saved under ``processed_file.id``. The
    exported animation is offered through a download button.

    Raises:
        ValueError: if the uploaded bytes cannot be decoded as an image.
    """
    pose_detector, skeleton_generator, animation_exporter = components

    raw_bytes = uploaded_file.getvalue()
    upload_key = (uploaded_file.name, hashlib.sha256(raw_bytes).hexdigest())

    if st.session_state.get("uploaded_image_key") != upload_key:
        image = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            # imdecode signals undecodable data by returning None.
            raise ValueError(f"Could not decode image '{uploaded_file.name}'")
        st.session_state.uploaded_image = image
        st.session_state.uploaded_image_key = upload_key
        # A new image invalidates any detection results cached for the old one.
        st.session_state.pop("original_skeleton_data", None)
        st.session_state.pop("processed_image", None)
    else:
        image = st.session_state.uploaded_image

    with col1:
        st.image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), use_column_width=True)

    if "original_skeleton_data" not in st.session_state:
        processed_image, skeleton_data = process_image(image, pose_detector, skeleton_generator)
        st.session_state.original_skeleton_data = skeleton_data
        st.session_state.processed_image = processed_image
    else:
        skeleton_data = st.session_state.original_skeleton_data
        processed_image = st.session_state.processed_image

    save_pose_data(db, processed_file.id, skeleton_data)
    animation_data_binary = animation_exporter.export_pose(skeleton_data)
    save_animation_data(db, processed_file.id, skeleton_data)

    with col2:
        processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
        st.image(processed_rgb, use_column_width=True)
        if st.session_state.get('manual_correction', False):
            # Message text: "Interactive correction is temporarily disabled."
            st.info("Интерактивная корректировка временно отключена.")

    # Offer the export, matching the video pipeline (previously computed but unused).
    provide_download_button(animation_data_binary)

def process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2):
    """Process a video or GIF upload and offer the exported animation for download.

    The upload is written to a temporary file so that the path-based helpers
    ``process_video`` / ``process_gif`` can open it. The file keeps the
    original extension and is deleted afterwards. The original is shown in
    *col1*: GIFs go through ``st.image``, since ``st.video`` cannot play GIF
    data. The processed result is shown in *col2*. Per-frame pose data is
    stored via ``save_video_data``.

    Raises:
        ValueError: if no poses were detected in any frame.
    """
    pose_detector, skeleton_generator, animation_exporter = components
    progress_bar = st.progress(0)

    raw_bytes = uploaded_file.getvalue()
    suffix = os.path.splitext(uploaded_file.name)[1]
    # delete=False and close before use: the helpers reopen the file by path,
    # which is not possible on Windows while a NamedTemporaryFile is still open.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
        tfile.write(raw_bytes)
        video_path = tfile.name

    try:
        with col1:
            if is_gif:
                st.image(raw_bytes, use_column_width=True)
            else:
                st.video(raw_bytes)

        process = process_gif if is_gif else process_video
        processed_video_path, animation_frames = process(video_path, pose_detector, skeleton_generator)

        if not animation_frames:
            raise ValueError("No poses detected in the video/gif")

        save_video_data(db, processed_file.id, animation_frames)
        animation_data_binary = animation_exporter.export_animation(animation_frames)
        progress_bar.progress(100)

        with col2:
            if processed_video_path:
                with open(processed_video_path, "rb") as f:
                    st.video(f.read())

        provide_download_button(animation_data_binary)
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

def save_pose_data(db, file_id: int, skeleton_data: dict):
    """Store *skeleton_data* as the landmarks of a new ``PoseData`` row and commit."""
    row = PoseData(file_id=file_id, landmarks=skeleton_data)
    db.add(row)
    db.commit()

def save_animation_data(db, file_id: int, skeleton_data: dict):
    """Store *skeleton_data* in a new ``AnimationData`` row for *file_id* and commit."""
    row = AnimationData(file_id=file_id, skeleton_data=skeleton_data)
    db.add(row)
    db.commit()

def save_video_data(db, file_id: int, animation_frames: list):
    """Store one ``PoseData`` row per frame (numbered from 0), then commit once."""
    rows = [
        PoseData(file_id=file_id, frame_number=index, landmarks=landmarks)
        for index, landmarks in enumerate(animation_frames)
    ]
    for row in rows:
        db.add(row)
    db.commit()

def provide_download_button(animation_data_binary, file_name: str = "animation.uasset",
                            mime: str = "application/octet-stream"):
    """Render a download button for exported animation data.

    Args:
        animation_data_binary: The exported payload handed to ``st.download_button``.
        file_name: Suggested download name; defaults to the previous
            hard-coded ``"animation.uasset"``.
        mime: MIME type sent with the download.

    NOTE(review): the sidebar "Export Format" selection is not consulted
    anywhere, so callers currently always get the default name.
    """
    st.download_button(
        label="Download Animation Data",
        data=animation_data_binary,
        file_name=file_name,
        mime=mime,
    )

def save_corrected_pose(db, file_id: int, joints: dict):
    """Store *joints* as a manually corrected pose for *file_id* and commit.

    The same *joints* mapping is written as both the raw and the corrected
    landmarks, and the row is flagged ``is_corrected=True``.
    """
    row = PoseData(
        file_id=file_id,
        landmarks=joints,
        corrected_landmarks=joints,
        is_corrected=True,
    )
    db.add(row)
    db.commit()

def show_instructions():
    """Render a collapsible "Instructions" panel describing the workflow.

    The supported-formats list matches the extensions accepted by the
    uploader in ``main`` (jpg, jpeg, png, mp4, avi, gif).
    """
    # NOTE(review): step 3 describes the confidence slider, but main() does not
    # pass the slider value to PoseDetector — confirm before relying on it.
    with st.expander("Instructions"):
        st.markdown("""
        1. Upload an image or video using the file uploader.
        2. Wait for processing to complete.
        3. The **Detection Confidence** slider in the sidebar sets the minimum confidence threshold for detecting a pose.
           A higher value (e.g., 0.7) means that only detections with high certainty are considered, which may improve accuracy.
        4. The processed image/video will be displayed in the "Processed" panel.
        5. Download animation data.
        
        Supported formats:
        - Images: JPG, JPEG, PNG
        - Videos: MP4, AVI, GIF
        """)

if __name__ == "__main__":
    # Render the upload/processing workflow first, then the instructions
    # expander. The instructions are also shown when main() returns early
    # (e.g. no file uploaded yet).
    main()
    show_instructions()