import streamlit as st import cv2 import tempfile import os import numpy as np import time from src.core.pose import PoseEstimator from src.core.compare import BiomechanicsMatcher from src.core.overlay import create_combined_overlay from src.coach import get_coaching_feedback from src.core.cache_utils import save_to_cache, load_from_cache st.set_page_config(page_title="AI Taekwondo Coach", layout="wide") # Custom CSS for Premium Design st.markdown(""" """, unsafe_allow_html=True) st.title("🥋 AI Sports Coach - Taekwondo") st.markdown("### Phân tích và so sánh kỹ thuật bài quyền") # Sidebar Configuration st.sidebar.header("⚙️ Cấu hình AI") model_choice = st.sidebar.radio( "Độ phức tạp Model", ["Lite (Nhanh)", "Full (Cân bằng)", "Heavy (Chính xác)"], index=1, help="Lite phù hợp cho Mobile, Heavy yêu cầu cấu hình mạnh hơn." ) model_type_map = {"Lite (Nhanh)": "lite", "Full (Cân bằng)": "full", "Heavy (Chính xác)": "heavy"} model_type = model_type_map[model_choice] skip_step = st.sidebar.slider("Nhảy khung hình (Skip Frames)", 1, 5, 2, help="Giá trị cao giúp xử lý nhanh hơn nhưng biểu đồ có thể bớt mượt.") resize_width = st.sidebar.slider("Độ phân giải xử lý (Width)", 240, 720, 480, step=80, help="Hạ độ phân giải giúp MediaPipe chạy nhanh hơn.") def cleanup_old_videos(output_dir="outputs", max_age_seconds=300): """Xóa các video cũ hơn 5 phút (300 giây).""" if not os.path.exists(output_dir): return now = time.time() for f in os.listdir(output_dir): f_path = os.path.join(output_dir, f) if os.path.isfile(f_path) and now - os.path.getmtime(f_path) > max_age_seconds: try: os.remove(f_path) except Exception as e: print(f"Error cleaning up {f}: {e}") def process_video(video_path, progress_text, model_type="full", skip_step=1, resize_width=480): cap = cv2.VideoCapture(video_path) estimator = PoseEstimator(model_type=model_type, resize_width=resize_width) landmarks_seq = [] frames = [] frame_count_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) progress_bar = st.progress(0, text=progress_text) current_frame = 0 while cap.isOpened(): ret, frame = cap.read() if not ret: break # Optimization: Frame Skipping if current_frame % skip_step == 0: results = estimator.process_frame(frame) landmarks = estimator.extract_landmarks(results) landmarks_seq.append({"frame": current_frame, "landmarks": landmarks}) # Store metadata for visualization (every 2nd processed frame) if (current_frame // skip_step) % 2 == 0: frames.append(frame) current_frame += 1 if current_frame % (skip_step * 5) == 0: progress_msg = f"{progress_text}: {int((current_frame/frame_count_total)*100)}%" progress_bar.progress(min(current_frame / frame_count_total, 1.0), text=progress_msg) cap.release() estimator.close() progress_bar.empty() return landmarks_seq, frames # UI Layout: Two columns for uploads col1, col2 = st.columns(2) with col1: st.markdown('
', unsafe_allow_html=True) st.subheader("1. Video Mẫu (Chuyên gia)") ref_file = st.file_uploader("Upload video mẫu", type=["mp4", "mov"], key="ref") st.markdown('
', unsafe_allow_html=True) with col2: st.markdown('
', unsafe_allow_html=True) st.subheader("2. Video Của Bạn") user_file = st.file_uploader("Upload video thực hiện", type=["mp4", "mov"], key="user") st.markdown('
', unsafe_allow_html=True) if ref_file and user_file: if st.button("🚀 BẮT ĐẦU PHÂN TÍCH"): # Save temp files with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_ref, \ tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_user: tmp_ref.write(ref_file.read()) tmp_user.write(user_file.read()) path_ref = tmp_ref.name path_user = tmp_user.name try: # 1. Process Reference Video (with Caching) with st.status("Đang chuẩn bị video mẫu...") as status: # Try loading from cache first ref_landmarks = load_from_cache(path_ref) if ref_landmarks: status.update(label="⚡ Đã tải video mẫu từ Cache!", state="complete") else: status.update(label="Đang phân tích video mẫu (lần đầu)...") # Use "heavy" model for reference by default for best quality ref_landmarks, _ = process_video(path_ref, "Đang trích xuất khung xương mẫu", model_type="heavy", skip_step=1, resize_width=720) save_to_cache(path_ref, ref_landmarks) status.update(label="Xong video mẫu!", state="complete") # 2. Process User Video (Optimized) with st.status("Đang phân tích video của bạn...") as status: user_landmarks, user_frames = process_video(path_user, "Đang trích xuất khung xương của bạn", model_type=model_type, skip_step=skip_step, resize_width=resize_width) status.update(label="Xong video người dùng!", state="complete") # 3. Perform Biomechanics Comparison st.markdown("---") st.subheader("📊 Kết Quả Phân Tích") matcher = BiomechanicsMatcher(ref_landmarks) path, errors = matcher.compare(user_landmarks) if path: score = matcher.get_summary_score(errors) # Metrics Dashboard st.markdown("### 📈 Chỉ Số Hiệu Suất") res_col1, res_col2, res_col3 = st.columns(3) res_col1.metric("Độ chính xác", f"{int(score)}%", delta=f"{int(score-70)}% so với mục tiêu") res_col2.metric("Nhịp điệu", "Khá tốt", delta="Đồng bộ cao") major_errors = len([e for e in errors if e["total"] > 90.0]) res_col3.metric("Lỗi cần sửa", major_errors, delta_color="inverse") # Detailed Analysis Chart st.markdown("#### Biểu đồ sai lệch theo thời gian") error_values = [e["total"] for e in errors] st.line_chart(error_values) # Body Part Breakdown st.markdown("#### Phân tích chi tiết từng bộ phận") # Calculate avg error per feature (8 angles + 2 hands) all_diffs = [] for e in errors: all_diffs.append(np.concatenate([e["angles"], e["hands"]])) avg_feats = np.mean(all_diffs, axis=0) from src.core.overlay import FEATURE_LABELS breakdown_cols = st.columns(5) for i in range(10): with breakdown_cols[i % 5]: err_val = int(avg_feats[i]) status = "✅" if err_val < 30 else "⚠️" if err_val < 60 else "❌" st.write(f"{status} **{FEATURE_LABELS[i]}**") st.caption(f"Lệch: {err_val}°") # 4. Generate Result Video st.markdown("---") st.markdown("### 🎥 Video Phân Tích Chi Tiết") # Cleanup and prepare filename cleanup_old_videos() timestamp = int(time.time()) # Sanitize filenames (remove extension and spaces) ref_name = os.path.splitext(ref_file.name)[0].replace(" ", "_") user_name = os.path.splitext(user_file.name)[0].replace(" ", "_") output_video_path = f"outputs/{ref_name}_{user_name}_{score:.2f}_{timestamp}.webm" from src.core.overlay import generate_result_video with st.spinner("Đang render video phân tích (High-quality WebM)..."): generate_result_video(path_user, user_landmarks, path, errors, output_video_path) if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0: with open(output_video_path, 'rb') as v_file: video_bytes = v_file.read() st.video(video_bytes) st.success("Video đã được AI xử lý. Các vị trí khoanh đỏ là nơi bạn cần tinh chỉnh góc độ.") else: st.error("Lỗi khi tạo video kết quả.") # 5. AI Coach Insights st.markdown("---") st.markdown("### 🤖 Lời Khuyên Từ Huấn Luyện Viên AI") with st.chat_message("assistant", avatar="🥋"): with st.spinner("AI đang phân tích các lỗi sai và chuẩn bị lời khuyên..."): feedback = get_coaching_feedback(score, avg_feats, FEATURE_LABELS) st.markdown(feedback) st.balloons() else: st.error("Không thể đồng bộ hóa hai video. Hãy đảm bảo cả hai đều quay rõ người.") finally: # Cleanup if os.path.exists(path_ref): os.unlink(path_ref) if os.path.exists(path_user): os.unlink(path_user) else: st.info("Vui lòng tải lên cả hai video để bắt đầu so sánh.")