AI-Coach / temp.py
anhlehong
feat/enhance
1c58706
import streamlit as st
import cv2
import tempfile
import os
import numpy as np
import time
from src.core.pose import PoseEstimator
from src.core.compare import BiomechanicsMatcher
from src.core.overlay import create_combined_overlay
from src.coach import get_coaching_feedback
from src.core.cache_utils import save_to_cache, load_from_cache
st.set_page_config(page_title="AI Taekwondo Coach", layout="wide")
# Custom CSS for Premium Design
st.markdown("""
<style>
.main {
background-color: #0e1117;
color: white;
}
.stButton>button {
width: 100%;
border-radius: 20px;
height: 3.5em;
background: linear-gradient(45deg, #ff4b4b, #ff8f8f);
color: white;
font-weight: bold;
border: none;
transition: 0.3s;
}
.stButton>button:hover {
transform: scale(1.02);
box-shadow: 0 10px 20px rgba(255, 75, 75, 0.3);
}
.upload-card {
background-color: #1e2130;
padding: 20px;
border-radius: 15px;
border: 1px solid #3e4259;
}
</style>
""", unsafe_allow_html=True)
st.title("🥋 AI Sports Coach - Taekwondo")
st.markdown("### Phân tích và so sánh kỹ thuật bài quyền")
# Sidebar Configuration
st.sidebar.header("⚙️ Cấu hình AI")
model_choice = st.sidebar.radio(
"Độ phức tạp Model",
["Lite (Nhanh)", "Full (Cân bằng)", "Heavy (Chính xác)"],
index=1,
help="Lite phù hợp cho Mobile, Heavy yêu cầu cấu hình mạnh hơn."
)
model_type_map = {"Lite (Nhanh)": "lite", "Full (Cân bằng)": "full", "Heavy (Chính xác)": "heavy"}
model_type = model_type_map[model_choice]
skip_step = st.sidebar.slider("Nhảy khung hình (Skip Frames)", 1, 5, 2, help="Giá trị cao giúp xử lý nhanh hơn nhưng biểu đồ có thể bớt mượt.")
resize_width = st.sidebar.slider("Độ phân giải xử lý (Width)", 240, 720, 480, step=80, help="Hạ độ phân giải giúp MediaPipe chạy nhanh hơn.")
def cleanup_old_videos(output_dir="outputs", max_age_seconds=300):
"""Xóa các video cũ hơn 5 phút (300 giây)."""
if not os.path.exists(output_dir):
return
now = time.time()
for f in os.listdir(output_dir):
f_path = os.path.join(output_dir, f)
if os.path.isfile(f_path) and now - os.path.getmtime(f_path) > max_age_seconds:
try:
os.remove(f_path)
except Exception as e:
print(f"Error cleaning up {f}: {e}")
def process_video(video_path, progress_text, model_type="full", skip_step=1, resize_width=480):
cap = cv2.VideoCapture(video_path)
estimator = PoseEstimator(model_type=model_type, resize_width=resize_width)
landmarks_seq = []
frames = []
frame_count_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
progress_bar = st.progress(0, text=progress_text)
current_frame = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
# Optimization: Frame Skipping
if current_frame % skip_step == 0:
results = estimator.process_frame(frame)
landmarks = estimator.extract_landmarks(results)
landmarks_seq.append({"frame": current_frame, "landmarks": landmarks})
# Store metadata for visualization (every 2nd processed frame)
if (current_frame // skip_step) % 2 == 0:
frames.append(frame)
current_frame += 1
if current_frame % (skip_step * 5) == 0:
progress_msg = f"{progress_text}: {int((current_frame/frame_count_total)*100)}%"
progress_bar.progress(min(current_frame / frame_count_total, 1.0), text=progress_msg)
cap.release()
estimator.close()
progress_bar.empty()
return landmarks_seq, frames
# UI Layout: Two columns for uploads
col1, col2 = st.columns(2)
with col1:
st.markdown('<div class="upload-card">', unsafe_allow_html=True)
st.subheader("1. Video Mẫu (Chuyên gia)")
ref_file = st.file_uploader("Upload video mẫu", type=["mp4", "mov"], key="ref")
st.markdown('</div>', unsafe_allow_html=True)
with col2:
st.markdown('<div class="upload-card">', unsafe_allow_html=True)
st.subheader("2. Video Của Bạn")
user_file = st.file_uploader("Upload video thực hiện", type=["mp4", "mov"], key="user")
st.markdown('</div>', unsafe_allow_html=True)
if ref_file and user_file:
if st.button("🚀 BẮT ĐẦU PHÂN TÍCH"):
# Save temp files
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_ref, \
tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_user:
tmp_ref.write(ref_file.read())
tmp_user.write(user_file.read())
path_ref = tmp_ref.name
path_user = tmp_user.name
try:
# 1. Process Reference Video (with Caching)
with st.status("Đang chuẩn bị video mẫu...") as status:
# Try loading from cache first
ref_landmarks = load_from_cache(path_ref)
if ref_landmarks:
status.update(label="⚡ Đã tải video mẫu từ Cache!", state="complete")
else:
status.update(label="Đang phân tích video mẫu (lần đầu)...")
# Use "heavy" model for reference by default for best quality
ref_landmarks, _ = process_video(path_ref, "Đang trích xuất khung xương mẫu",
model_type="heavy", skip_step=1, resize_width=720)
save_to_cache(path_ref, ref_landmarks)
status.update(label="Xong video mẫu!", state="complete")
# 2. Process User Video (Optimized)
with st.status("Đang phân tích video của bạn...") as status:
user_landmarks, user_frames = process_video(path_user, "Đang trích xuất khung xương của bạn",
model_type=model_type, skip_step=skip_step, resize_width=resize_width)
status.update(label="Xong video người dùng!", state="complete")
# 3. Perform Biomechanics Comparison
st.markdown("---")
st.subheader("📊 Kết Quả Phân Tích")
matcher = BiomechanicsMatcher(ref_landmarks)
path, errors = matcher.compare(user_landmarks)
if path:
score = matcher.get_summary_score(errors)
# Metrics Dashboard
st.markdown("### 📈 Chỉ Số Hiệu Suất")
res_col1, res_col2, res_col3 = st.columns(3)
res_col1.metric("Độ chính xác", f"{int(score)}%", delta=f"{int(score-70)}% so với mục tiêu")
res_col2.metric("Nhịp điệu", "Khá tốt", delta="Đồng bộ cao")
major_errors = len([e for e in errors if e["total"] > 90.0])
res_col3.metric("Lỗi cần sửa", major_errors, delta_color="inverse")
# Detailed Analysis Chart
st.markdown("#### Biểu đồ sai lệch theo thời gian")
error_values = [e["total"] for e in errors]
st.line_chart(error_values)
# Body Part Breakdown
st.markdown("#### Phân tích chi tiết từng bộ phận")
# Calculate avg error per feature (8 angles + 2 hands)
all_diffs = []
for e in errors:
all_diffs.append(np.concatenate([e["angles"], e["hands"]]))
avg_feats = np.mean(all_diffs, axis=0)
from src.core.overlay import FEATURE_LABELS
breakdown_cols = st.columns(5)
for i in range(10):
with breakdown_cols[i % 5]:
err_val = int(avg_feats[i])
status = "✅" if err_val < 30 else "⚠️" if err_val < 60 else "❌"
st.write(f"{status} **{FEATURE_LABELS[i]}**")
st.caption(f"Lệch: {err_val}°")
# 4. Generate Result Video
st.markdown("---")
st.markdown("### 🎥 Video Phân Tích Chi Tiết")
# Cleanup and prepare filename
cleanup_old_videos()
timestamp = int(time.time())
# Sanitize filenames (remove extension and spaces)
ref_name = os.path.splitext(ref_file.name)[0].replace(" ", "_")
user_name = os.path.splitext(user_file.name)[0].replace(" ", "_")
output_video_path = f"outputs/{ref_name}_{user_name}_{score:.2f}_{timestamp}.webm"
from src.core.overlay import generate_result_video
with st.spinner("Đang render video phân tích (High-quality WebM)..."):
generate_result_video(path_user, user_landmarks, path, errors, output_video_path)
if os.path.exists(output_video_path) and os.path.getsize(output_video_path) > 0:
with open(output_video_path, 'rb') as v_file:
video_bytes = v_file.read()
st.video(video_bytes)
st.success("Video đã được AI xử lý. Các vị trí khoanh đỏ là nơi bạn cần tinh chỉnh góc độ.")
else:
st.error("Lỗi khi tạo video kết quả.")
# 5. AI Coach Insights
st.markdown("---")
st.markdown("### 🤖 Lời Khuyên Từ Huấn Luyện Viên AI")
with st.chat_message("assistant", avatar="🥋"):
with st.spinner("AI đang phân tích các lỗi sai và chuẩn bị lời khuyên..."):
feedback = get_coaching_feedback(score, avg_feats, FEATURE_LABELS)
st.markdown(feedback)
st.balloons()
else:
st.error("Không thể đồng bộ hóa hai video. Hãy đảm bảo cả hai đều quay rõ người.")
finally:
# Cleanup
if os.path.exists(path_ref): os.unlink(path_ref)
if os.path.exists(path_user): os.unlink(path_user)
else:
st.info("Vui lòng tải lên cả hai video để bắt đầu so sánh.")