vertalius commited on
Commit
31c6ab0
·
verified ·
1 Parent(s): d9c1c4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -105
app.py CHANGED
@@ -12,24 +12,51 @@ from utils import process_video, process_image, process_gif
12
  from database import get_db, ProcessedFile, PoseData, AnimationData
13
 
14
  def init_page():
15
- """Initialize Streamlit page configuration and styling."""
16
  st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")
17
- with open('static/style.css') as f:
18
- st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
19
 
20
- # Theme selection
21
- theme = st.sidebar.selectbox("Theme", ["Light", "Dark"], key="theme")
22
- if theme == "Dark":
23
- st.markdown("""
24
- <style>
25
- .stApp { background-color: #1E1E1E; color: #FFFFFF; }
26
- </style>
27
- """, unsafe_allow_html=True)
 
28
 
29
- # Detection settings
30
- st.sidebar.title("Settings")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
 
32
  # Detection settings
 
33
  confidence_threshold = st.sidebar.slider(
34
  "Detection Confidence",
35
  min_value=0.0,
@@ -50,11 +77,6 @@ def init_page():
50
  if enable_corrections:
51
  st.sidebar.info("Click on landmarks in the preview to adjust their positions")
52
 
53
- # Custom skeleton mapping
54
- show_mapping = st.sidebar.expander("Skeleton Mapping")
55
- with show_mapping:
56
- st.text_area("Custom Mapping (JSON)", value="{}", key="custom_mapping")
57
-
58
  st.title("Pose Detection & Animation Generator")
59
  return confidence_threshold
60
 
@@ -64,8 +86,6 @@ def init_components() -> Tuple[PoseDetector, SkeletonGenerator, AnimationExporte
64
 
65
  def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
66
  """Process uploaded file and store results in database."""
67
- pose_detector, skeleton_generator, animation_exporter = components
68
-
69
  processed_file = ProcessedFile(
70
  filename=uploaded_file.name,
71
  file_type='video' if uploaded_file.type == 'image/gif' else file_type,
@@ -74,50 +94,20 @@ def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session)
74
  db_session.add(processed_file)
75
  db_session.commit()
76
  db_session.refresh(processed_file)
77
-
78
  return processed_file
79
 
80
  def main():
81
  init_page()
82
  components = init_components()
83
 
84
- try:
85
- uploaded_file = st.file_uploader(
86
- "Choose an image or video file (max 50MB)",
87
- type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif']
88
- )
89
- if uploaded_file is not None:
90
- st.cache_data.clear() # Clear cache to prevent stale data
91
- except Exception as e:
92
- st.error("Network error occurred. Please try uploading again.")
93
- return
94
-
95
- if uploaded_file is None:
96
  st.warning("Please upload a file to begin.")
97
  return
98
-
99
- if uploaded_file.type == 'video/mp4':
100
- try:
101
- st.info("Processing video... This may take a moment.")
102
- file_size = len(uploaded_file.getvalue()) / (1024 * 1024) # Size in MB
103
- if file_size > 50:
104
- st.error("Video file size must be under 50MB. Please upload a smaller file.")
105
- return
106
-
107
- # Validate video file
108
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
109
- tfile.write(uploaded_file.getvalue())
110
- cap = cv2.VideoCapture(tfile.name)
111
- if not cap.isOpened():
112
- st.error("Invalid video file. Please try a different file.")
113
- return
114
- cap.release()
115
- except Exception as e:
116
- st.error(f"Error processing video: {str(e)}")
117
- return
118
-
119
- if uploaded_file is None:
120
- return
121
 
122
  db = next(get_db())
123
  try:
@@ -135,7 +125,7 @@ def main():
135
  try:
136
  if file_type == 'image' and not is_gif:
137
  process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
138
- elif file_type == 'video' or is_gif:
139
  process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
140
  except Exception as e:
141
  st.error(f"Processing error: {str(e)}")
@@ -171,39 +161,33 @@ def process_image_upload(uploaded_file, components, processed_file, db, col1, co
171
  save_animation_data(db, processed_file.id, skeleton_data)
172
 
173
  with col2:
174
- # Create a canvas for manual corrections
175
- canvas_container = st.empty()
176
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
177
-
178
- # Add manual correction controls
 
179
  if st.button("Enable Manual Correction"):
180
  st.session_state.manual_correction = True
181
- st.session_state.current_landmarks = skeleton_data
182
 
183
  if st.session_state.get('manual_correction', False):
184
- # Display current joint positions
185
  joints = st.session_state.current_landmarks
186
-
187
  selected_joint = st.selectbox("Select Joint to Adjust", list(joints.keys()))
188
 
189
- col1, col2 = st.columns(2)
190
- with col1:
191
  x_pos = st.slider("X Position", 0.0, 1.0, float(joints[selected_joint]['position'][0]), 0.01)
192
- with col2:
193
  y_pos = st.slider("Y Position", 0.0, 1.0, float(joints[selected_joint]['position'][1]), 0.01)
194
 
195
  if st.button("Apply Changes"):
196
- joints[selected_joint]['position'][0] = x_pos
197
- joints[selected_joint]['position'][1] = y_pos
198
- st.session_state.current_landmarks = joints
199
  processed_image = pose_detector.draw_corrected_pose(image, joints)
200
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
 
201
 
202
  if st.button("Save Corrections"):
203
- save_corrected_pose(db, processed_file.id, st.session_state.current_landmarks)
204
  st.success("Corrections saved successfully!")
205
-
206
- canvas_container.image(processed_rgb, use_column_width=True)
207
 
208
  provide_download_button(animation_data_binary)
209
 
@@ -211,15 +195,15 @@ def process_video_upload(uploaded_file, components, processed_file, db, is_gif,
211
  """Handle video/GIF file upload processing."""
212
  pose_detector, skeleton_generator, animation_exporter = components
213
  progress_bar = st.progress(0)
214
- status_text = st.empty()
215
 
216
- with tempfile.NamedTemporaryFile(delete=False, suffix='.gif' if is_gif else '.mp4') as tfile:
217
  tfile.write(uploaded_file.read())
218
- video_path = tfile.name
219
-
220
  with col1:
221
- st.video(video_path)
222
 
 
223
  if is_gif:
224
  processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
225
  else:
@@ -233,20 +217,17 @@ def process_video_upload(uploaded_file, components, processed_file, db, is_gif,
233
 
234
  with col2:
235
  if processed_video_path:
236
- st.video(processed_video_path)
 
237
 
238
  provide_download_button(animation_data_binary)
239
-
240
- cleanup_temp_files(video_path, processed_video_path)
241
 
242
  def save_pose_data(db, file_id: int, skeleton_data: dict):
243
- """Save pose data to database."""
244
  pose_data = PoseData(file_id=file_id, landmarks=skeleton_data)
245
  db.add(pose_data)
246
  db.commit()
247
 
248
  def save_animation_data(db, file_id: int, skeleton_data: dict):
249
- """Save animation data to database."""
250
  animation_data = AnimationData(
251
  file_id=file_id,
252
  skeleton_data=skeleton_data
@@ -255,7 +236,6 @@ def save_animation_data(db, file_id: int, skeleton_data: dict):
255
  db.commit()
256
 
257
  def save_video_data(db, file_id: int, animation_frames: list):
258
- """Save video frame data to database."""
259
  for frame_num, frame_data in enumerate(animation_frames):
260
  pose_data = PoseData(
261
  file_id=file_id,
@@ -266,7 +246,6 @@ def save_video_data(db, file_id: int, animation_frames: list):
266
  db.commit()
267
 
268
  def provide_download_button(animation_data_binary):
269
- """Provide download button for animation data."""
270
  st.download_button(
271
  label="Download Animation Data",
272
  data=animation_data_binary,
@@ -274,30 +253,19 @@ def provide_download_button(animation_data_binary):
274
  mime="application/octet-stream"
275
  )
276
 
277
- def cleanup_temp_files(*file_paths):
278
- """Clean up temporary files."""
279
- for file_path in file_paths:
280
- if file_path:
281
- try:
282
- import os
283
- os.unlink(file_path)
284
- except Exception:
285
- pass
286
-
287
  def show_instructions():
288
- """Show usage instructions."""
289
  with st.expander("Instructions"):
290
  st.markdown("""
291
- 1. Upload an image or video file using the file uploader above
292
- 2. Wait for the pose detection and skeleton generation to complete
293
- 3. Preview the results in the right column
294
- 4. Download the animation data for use in Unreal Engine
295
-
296
- Supported file formats:
297
- - Images: JPG, JPEG, PNG
298
- - Videos: MP4, AVI, GIF
299
  """)
300
 
301
  if __name__ == "__main__":
302
  main()
303
- show_instructions()
 
12
  from database import get_db, ProcessedFile, PoseData, AnimationData
13
 
14
  def init_page():
15
+ """Initialize Streamlit page configuration with embedded styling."""
16
  st.set_page_config(layout="wide", page_title="Pose Detection & Animation Generator")
 
 
17
 
18
+ # Embedded CSS styles
19
+ st.markdown("""
20
+ <style>
21
+ /* Base styling */
22
+ .stApp {
23
+ max-width: 100% !important;
24
+ padding: 2rem;
25
+ transition: background-color 0.3s;
26
+ }
27
 
28
+ /* Dark theme */
29
+ [data-theme="dark"] .stApp {
30
+ background-color: #1E1E1E;
31
+ color: #FFFFFF;
32
+ }
33
+
34
+ /* Widget styling */
35
+ .stButton>button {
36
+ background-color: #4CAF50;
37
+ color: white;
38
+ border-radius: 4px;
39
+ padding: 0.5rem 1rem;
40
+ }
41
+
42
+ .stDownloadButton>button {
43
+ background-color: #008CBA;
44
+ }
45
+
46
+ /* Progress bars */
47
+ .stProgress > div > div > div {
48
+ background-color: #4CAF50;
49
+ }
50
 
51
+ /* Columns spacing */
52
+ .stColumn {
53
+ padding: 0 1rem;
54
+ }
55
+ </style>
56
+ """, unsafe_allow_html=True)
57
+
58
  # Detection settings
59
+ st.sidebar.title("Settings")
60
  confidence_threshold = st.sidebar.slider(
61
  "Detection Confidence",
62
  min_value=0.0,
 
77
  if enable_corrections:
78
  st.sidebar.info("Click on landmarks in the preview to adjust their positions")
79
 
 
 
 
 
 
80
  st.title("Pose Detection & Animation Generator")
81
  return confidence_threshold
82
 
 
86
 
87
  def handle_upload(file_type: str, uploaded_file, components: Tuple, db_session) -> Optional[ProcessedFile]:
88
  """Process uploaded file and store results in database."""
 
 
89
  processed_file = ProcessedFile(
90
  filename=uploaded_file.name,
91
  file_type='video' if uploaded_file.type == 'image/gif' else file_type,
 
94
  db_session.add(processed_file)
95
  db_session.commit()
96
  db_session.refresh(processed_file)
 
97
  return processed_file
98
 
99
  def main():
100
  init_page()
101
  components = init_components()
102
 
103
+ uploaded_file = st.file_uploader(
104
+ "Choose an image or video file (max 50MB)",
105
+ type=['jpg', 'jpeg', 'png', 'mp4', 'avi', 'gif']
106
+ )
107
+
108
+ if not uploaded_file:
 
 
 
 
 
 
109
  st.warning("Please upload a file to begin.")
110
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  db = next(get_db())
113
  try:
 
125
  try:
126
  if file_type == 'image' and not is_gif:
127
  process_image_upload(uploaded_file, components, processed_file, db, col1, col2)
128
+ else:
129
  process_video_upload(uploaded_file, components, processed_file, db, is_gif, col1, col2)
130
  except Exception as e:
131
  st.error(f"Processing error: {str(e)}")
 
161
  save_animation_data(db, processed_file.id, skeleton_data)
162
 
163
  with col2:
 
 
164
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
165
+ canvas_container = st.empty()
166
+ canvas_container.image(processed_rgb, use_column_width=True)
167
+
168
  if st.button("Enable Manual Correction"):
169
  st.session_state.manual_correction = True
170
+ st.session_state.current_landmarks = skeleton_data.copy()
171
 
172
  if st.session_state.get('manual_correction', False):
 
173
  joints = st.session_state.current_landmarks
 
174
  selected_joint = st.selectbox("Select Joint to Adjust", list(joints.keys()))
175
 
176
+ col_x, col_y = st.columns(2)
177
+ with col_x:
178
  x_pos = st.slider("X Position", 0.0, 1.0, float(joints[selected_joint]['position'][0]), 0.01)
179
+ with col_y:
180
  y_pos = st.slider("Y Position", 0.0, 1.0, float(joints[selected_joint]['position'][1]), 0.01)
181
 
182
  if st.button("Apply Changes"):
183
+ joints[selected_joint]['position'] = [x_pos, y_pos]
 
 
184
  processed_image = pose_detector.draw_corrected_pose(image, joints)
185
  processed_rgb = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
186
+ canvas_container.image(processed_rgb, use_column_width=True)
187
 
188
  if st.button("Save Corrections"):
189
+ save_corrected_pose(db, processed_file.id, joints)
190
  st.success("Corrections saved successfully!")
 
 
191
 
192
  provide_download_button(animation_data_binary)
193
 
 
195
  """Handle video/GIF file upload processing."""
196
  pose_detector, skeleton_generator, animation_exporter = components
197
  progress_bar = st.progress(0)
 
198
 
199
+ with tempfile.NamedTemporaryFile() as tfile:
200
  tfile.write(uploaded_file.read())
201
+ tfile.seek(0)
202
+
203
  with col1:
204
+ st.video(tfile.read())
205
 
206
+ video_path = tfile.name
207
  if is_gif:
208
  processed_video_path, animation_frames = process_gif(video_path, pose_detector, skeleton_generator)
209
  else:
 
217
 
218
  with col2:
219
  if processed_video_path:
220
+ with open(processed_video_path, "rb") as f:
221
+ st.video(f.read())
222
 
223
  provide_download_button(animation_data_binary)
 
 
224
 
225
  def save_pose_data(db, file_id: int, skeleton_data: dict):
 
226
  pose_data = PoseData(file_id=file_id, landmarks=skeleton_data)
227
  db.add(pose_data)
228
  db.commit()
229
 
230
  def save_animation_data(db, file_id: int, skeleton_data: dict):
 
231
  animation_data = AnimationData(
232
  file_id=file_id,
233
  skeleton_data=skeleton_data
 
236
  db.commit()
237
 
238
  def save_video_data(db, file_id: int, animation_frames: list):
 
239
  for frame_num, frame_data in enumerate(animation_frames):
240
  pose_data = PoseData(
241
  file_id=file_id,
 
246
  db.commit()
247
 
248
  def provide_download_button(animation_data_binary):
 
249
  st.download_button(
250
  label="Download Animation Data",
251
  data=animation_data_binary,
 
253
  mime="application/octet-stream"
254
  )
255
 
 
 
 
 
 
 
 
 
 
 
256
  def show_instructions():
 
257
  with st.expander("Instructions"):
258
  st.markdown("""
259
+ 1. Upload an image/video using the file uploader
260
+ 2. Wait for processing to complete
261
+ 3. Preview results in the right column
262
+ 4. Download animation data
263
+
264
+ Supported formats:
265
+ - Images: JPG, PNG
266
+ - Videos: MP4, GIF
267
  """)
268
 
269
  if __name__ == "__main__":
270
  main()
271
+ show_instructions()