| from __future__ import annotations |
|
|
| import base64 |
| from contextlib import suppress |
| from pathlib import Path |
| from functools import lru_cache |
| import ast |
| import json |
| import logging |
| import os |
| import re |
| import unicodedata |
| from typing import Any, TypedDict |
|
|
| from dotenv import load_dotenv |
|
|
|
|
| load_dotenv() |
|
|
| logger = logging.getLogger(__name__) |
|
|
| NVIDIA_BASE_URL = "https://integrate.api.nvidia.com/v1" |
| DEFAULT_NVIDIA_MODEL = "deepseek-ai/deepseek-v4-flash" |
| DEFAULT_NVIDIA_VISION_MODEL = "meta/llama-3.2-11b-vision-instruct" |
| DEFAULT_NVIDIA_VISION_MAX_TOKENS = 1200 |
| MAX_FEEDBACK_CHARS = 850 |
|
|
|
|
| class FeedbackState(TypedDict, total=False): |
| analysis: dict[str, Any] |
| evidence: dict[str, Any] |
| feedback: str |
| warnings: list[str] |
|
|
|
|
| def _env_int(name: str, default: int) -> int: |
| raw_value = os.getenv(name, "").strip() |
| if not raw_value: |
| return default |
| try: |
| return int(raw_value) |
| except ValueError: |
| return default |
|
|
|
|
| def generate_coach_feedback(analysis_result: dict[str, Any]) -> str: |
| """Generate a short Vietnamese coaching note from deterministic analysis. |
| |
| The graph is intentionally narrow: MediaPipe/rules decide the errors, the |
| LLM only rewrites those facts into actionable coaching language. |
| """ |
| fallback = _fallback_feedback(analysis_result) |
| api_key = os.getenv("NVIDIA_API_KEY", "").strip() |
| if not api_key or api_key == "nvapi-...": |
| return fallback |
|
|
| try: |
| graph = _compiled_graph() |
| final_state = graph.invoke( |
| { |
| "analysis": analysis_result, |
| "feedback": fallback, |
| "warnings": [], |
| } |
| ) |
| return final_state.get("feedback") or fallback |
| except Exception: |
| return fallback |
|
|
|
|
| def generate_rep_visual_feedback( |
| *, |
| student_image_path: str, |
| expert_image_path: str, |
| rep_context: dict[str, Any], |
| ) -> dict[str, Any]: |
| """Use a VLM to inspect one student/expert frame pair. |
| |
| Returns compact text feedback. Arrow placement is handled downstream from |
| rule-based landmarks so it stays anchored to the student's body. |
| """ |
| api_key = os.getenv("NVIDIA_API_KEY", "").strip() |
| if not api_key or api_key == "nvapi-...": |
| return _rep_visual_fallback(rep_context, "missing NVIDIA_API_KEY") |
|
|
| try: |
| comparison_image_path: Path | None = None |
| try: |
| comparison_image_path = _make_comparison_image(student_image_path, expert_image_path) |
| parsed = _call_and_parse_vlm_for_rep( |
| comparison_image_path=str(comparison_image_path), |
| rep_context=rep_context, |
| api_key=api_key, |
| ) |
| finally: |
| if comparison_image_path is not None: |
| with suppress(OSError): |
| comparison_image_path.unlink(missing_ok=True) |
| return _validate_rep_visual_feedback(parsed, rep_context) |
| except Exception as exc: |
| logger.warning("VLM rep feedback failed: %s", exc) |
| return _rep_visual_fallback(rep_context, str(exc)) |
|
|
|
|
| @lru_cache(maxsize=1) |
| def _compiled_graph(): |
| from langgraph.graph import END, START, StateGraph |
|
|
| graph = StateGraph(FeedbackState) |
| graph.add_node("build_evidence", _build_evidence) |
| graph.add_node("generate_feedback", _generate_feedback) |
| graph.add_node("validate_feedback", _validate_feedback) |
| graph.add_edge(START, "build_evidence") |
| graph.add_edge("build_evidence", "generate_feedback") |
| graph.add_edge("generate_feedback", "validate_feedback") |
| graph.add_edge("validate_feedback", END) |
| return graph.compile() |
|
|
|
|
| def _build_evidence(state: FeedbackState) -> FeedbackState: |
| analysis = state["analysis"] |
| rep_results = analysis.get("rep_results", []) |
| weakest_reps = sorted(rep_results, key=lambda rep: rep.get("score_pct", 100))[:3] |
|
|
| evidence = { |
| "exercise": analysis.get("exercise_label", "Hít đất"), |
| "overall_score_pct": analysis.get("overall_score_pct", 0), |
| "student_reps": analysis.get("student_reps", 0), |
| "expert_reps": analysis.get("expert_reps", 0), |
| "good_reps": analysis.get("good_reps", 0), |
| "serious_reps": analysis.get("serious_reps", 0), |
| "summary": analysis.get("summary", ""), |
| "main_errors": [ |
| { |
| "label": err.get("label", ""), |
| "count": err.get("count", 0), |
| "severity": err.get("severity", ""), |
| "guidance": err.get("guidance", ""), |
| } |
| for err in analysis.get("main_errors", [])[:3] |
| ], |
| "weakest_reps": [ |
| { |
| "rep_num": rep.get("rep_num"), |
| "score_pct": rep.get("score_pct"), |
| "status": rep.get("status"), |
| "errors": rep.get("error_labels", []), |
| "feedback": rep.get("feedback", ""), |
| } |
| for rep in weakest_reps |
| ], |
| } |
| return {**state, "evidence": evidence} |
|
|
|
|
| def _generate_feedback(state: FeedbackState) -> FeedbackState: |
| from openai import OpenAI |
|
|
| client = OpenAI( |
| base_url=os.getenv("NVIDIA_BASE_URL", NVIDIA_BASE_URL), |
| api_key=os.getenv("NVIDIA_API_KEY"), |
| ) |
| model = os.getenv("NVIDIA_MODEL", DEFAULT_NVIDIA_MODEL) |
| prompt = _feedback_prompt(state["evidence"]) |
| completion = client.chat.completions.create( |
| model=model, |
| messages=[ |
| { |
| "role": "system", |
| "content": ( |
| "Bạn là huấn luyện viên thể hình. Chỉ dùng dữ liệu đã cung cấp, " |
| "không bịa lỗi mới. Trả lời tiếng Việt, ngắn gọn, tối đa 150 chữ." |
| ), |
| }, |
| {"role": "user", "content": prompt}, |
| ], |
| temperature=0.25, |
| top_p=0.9, |
| max_tokens=500, |
| ) |
| feedback = (completion.choices[0].message.content or "").strip() |
| return {**state, "feedback": feedback} |
|
|
|
|
| def _call_vlm_for_rep( |
| *, |
| comparison_image_path: str, |
| rep_context: dict[str, Any], |
| api_key: str, |
| retry: bool = False, |
| ) -> str: |
| from openai import OpenAI |
|
|
| client = OpenAI( |
| base_url=os.getenv("NVIDIA_BASE_URL", NVIDIA_BASE_URL), |
| api_key=api_key, |
| ) |
| model = os.getenv("NVIDIA_VISION_MODEL") or DEFAULT_NVIDIA_VISION_MODEL |
| completion = client.chat.completions.create( |
| model=model, |
| messages=[ |
| { |
| "role": "system", |
| "content": ( |
| "Bạn là huấn luyện viên hít đất. Bạn nhận 2 nguồn thông tin: " |
| "1) ảnh rep của học viên và ảnh mẫu mentor, 2) kết quả rule-based có thể sai. " |
| "Ưu tiên đánh giá trực tiếp từ ảnh. Nếu rule-based mâu thuẫn với ảnh, hãy theo ảnh. " |
| "Chỉ trả về JSON hợp lệ, không markdown. " |
| "Giữ mọi string value thật ngắn để JSON không bị cắt giữa chừng." |
| ), |
| }, |
| { |
| "role": "user", |
| "content": [ |
| { |
| "type": "text", |
| "text": _rep_vlm_prompt(rep_context, retry=retry), |
| }, |
| { |
| "type": "image_url", |
| "image_url": {"url": _image_as_data_url(comparison_image_path)}, |
| }, |
| ], |
| }, |
| ], |
| temperature=0.15, |
| top_p=0.9, |
| max_tokens=_env_int("NVIDIA_VISION_MAX_TOKENS", DEFAULT_NVIDIA_VISION_MAX_TOKENS), |
| ) |
| if not completion.choices: |
| raise ValueError("empty VLM choices") |
|
|
| choice = completion.choices[0] |
| content = _extract_message_text(choice.message) |
| if not content: |
| finish_reason = getattr(choice, "finish_reason", "") |
| raise ValueError(f"empty VLM content; finish_reason={finish_reason}; model={model}") |
| finish_reason = getattr(choice, "finish_reason", "") |
| if finish_reason == "length": |
| raise ValueError( |
| "VLM response was truncated before JSON completed; " |
| "increase NVIDIA_VISION_MAX_TOKENS or use a shorter vision model response" |
| ) |
| return content.strip() |
|
|
|
|
| def _call_and_parse_vlm_for_rep( |
| *, |
| comparison_image_path: str, |
| rep_context: dict[str, Any], |
| api_key: str, |
| ) -> dict[str, Any]: |
| first_error = "" |
| for retry in (False, True): |
| try: |
| response_text = _call_vlm_for_rep( |
| comparison_image_path=comparison_image_path, |
| rep_context=rep_context, |
| api_key=api_key, |
| retry=retry, |
| ) |
| parsed = _parse_json_object(response_text) |
| if not retry and _needs_vlm_retry(parsed): |
| raise ValueError("VLM returned incomplete feedback") |
| return parsed |
| except Exception as exc: |
| first_error = first_error or str(exc) |
| if retry: |
| raise ValueError(f"invalid VLM JSON after retry: {exc}") from exc |
| logger.info("Retrying VLM rep feedback after parse failure: %s", exc) |
|
|
| raise ValueError(first_error or "unknown VLM parse failure") |
|
|
|
|
| def _extract_message_text(message: Any) -> str: |
| content = getattr(message, "content", "") |
| if isinstance(content, str): |
| return content |
| if isinstance(content, list): |
| parts = [] |
| for item in content: |
| if isinstance(item, dict): |
| text = item.get("text") or item.get("content") |
| if isinstance(text, str): |
| parts.append(text) |
| else: |
| text = getattr(item, "text", None) |
| if isinstance(text, str): |
| parts.append(text) |
| return "\n".join(parts) |
| return str(content or "") |
|
|
|
|
| def _validate_feedback(state: FeedbackState) -> FeedbackState: |
| feedback = (state.get("feedback") or "").strip() |
| fallback = _fallback_feedback(state["analysis"]) |
| if not feedback: |
| return {**state, "feedback": fallback} |
|
|
| if len(feedback) > MAX_FEEDBACK_CHARS: |
| feedback = feedback[:MAX_FEEDBACK_CHARS].rsplit(" ", 1)[0].strip() + "..." |
|
|
| allowed_labels = { |
| err.get("label", "").lower() |
| for err in state["analysis"].get("main_errors", []) |
| if err.get("label") |
| } |
| if allowed_labels and _mentions_unknown_error(feedback, allowed_labels): |
| return {**state, "feedback": fallback} |
|
|
| return {**state, "feedback": feedback} |
|
|
|
|
| def _feedback_prompt(evidence: dict[str, Any]) -> str: |
| return f""" |
| Bạn là huấn luyện viên thể hình cho người mới tập. |
| Dữ liệu phân tích bên dưới đã được tạo bởi hệ thống computer vision và rule engine. |
| |
| Yêu cầu: |
| - Chỉ dùng lỗi có trong dữ liệu, không tự bịa lỗi mới. |
| - Viết tiếng Việt, ngắn gọn, không quá 150 chữ. |
| - Có 4 phần: Nhận xét, Lỗi chính, Cách sửa, Bài tập bổ trợ. |
| - Nếu không có lỗi nghiêm trọng, tập trung khen form và nhắc giữ nhịp kiểm soát. |
| - Không nhắc tới JSON, AI, model, MediaPipe, DTW. |
| |
| Dữ liệu: |
| {json.dumps(evidence, ensure_ascii=False, indent=2)} |
| """.strip() |
|
|
|
|
| def _rep_vlm_prompt(rep_context: dict[str, Any], *, retry: bool = False) -> str: |
| retry_instruction = "" |
| if retry: |
| retry_instruction = """ |
| Lần gọi trước không trả về JSON hợp lệ. Lần này chỉ trả về đúng MỘT object JSON minified. |
| Không dùng markdown, không giải thích, không thêm chữ trước/sau JSON. |
| Tất cả key và string value phải dùng dấu nháy kép. |
| """.strip() |
|
|
| return f""" |
| Bạn nhận MỘT ảnh so sánh ghép ngang: |
| - Nửa TRÁI là rep của học viên. |
| - Nửa PHẢI là rep mẫu của mentor. |
| |
| Nguồn thông tin 1 - hình ảnh: |
| - So sánh trực tiếp tư thế học viên với mẫu: đầu/cổ, vai, lưng, hông, khuỷu tay, độ sâu khi xuống. |
| - Đây là nguồn ưu tiên cao nhất. |
| |
| Nguồn thông tin 2 - rule-based: |
| - Có thể đúng hoặc sai do góc quay/landmark nhiễu. |
| - Chỉ dùng như gợi ý phụ, không được chép lại nếu ảnh không ủng hộ. |
| |
| Rule context: |
| {json.dumps(rep_context, ensure_ascii=False, indent=2)} |
| |
| Trả về JSON đúng schema: |
| {{ |
| "is_error": true, |
| "visual_error_label": "tên lỗi chính quan sát từ ảnh hoặc 'Không có lỗi rõ ràng'", |
| "diagnosis": "nhận xét ngắn về rep này dựa trên ảnh", |
| "correction": "lời khuyên sửa cụ thể cho lần tập tiếp theo, hoặc 'Tiếp tục giữ form hiện tại' nếu đã đúng", |
| "feedback": "ghép diagnosis + correction thành 1-2 câu tiếng Việt cụ thể cho rep này" |
| }} |
| |
| Quy tắc: |
| - Không trả về tọa độ arrow. Arrow sẽ được hệ thống đặt từ landmark rule-based. |
| - diagnosis, correction, feedback đều phải ngắn. Feedback tối đa 2 câu. |
| - Feedback không được chỉ là tên lỗi. Phải nói lỗi ảnh hưởng gì và sửa bằng hành động cụ thể. |
| - Không trả lời ngoài JSON. |
| {retry_instruction} |
| """.strip() |
|
|
|
|
| def _image_as_data_url(path: str) -> str: |
| with open(path, "rb") as file: |
| encoded = base64.b64encode(file.read()).decode("ascii") |
| return f"data:image/jpeg;base64,{encoded}" |
|
|
|
|
| def _make_comparison_image(student_image_path: str, expert_image_path: str) -> Path: |
| import cv2 |
| import numpy as np |
|
|
| student = cv2.imread(student_image_path) |
| expert = cv2.imread(expert_image_path) |
| if student is None: |
| raise ValueError(f"Cannot read student image: {student_image_path}") |
| if expert is None: |
| raise ValueError(f"Cannot read expert image: {expert_image_path}") |
|
|
| target_height = min(student.shape[0], expert.shape[0], 720) |
|
|
| def resize_to_height(image): |
| ratio = target_height / image.shape[0] |
| width = max(1, int(image.shape[1] * ratio)) |
| return cv2.resize(image, (width, target_height), interpolation=cv2.INTER_AREA) |
|
|
| student_resized = resize_to_height(student) |
| expert_resized = resize_to_height(expert) |
| separator = np.full((target_height, 10, 3), 255, dtype=np.uint8) |
| combined = np.hstack([student_resized, separator, expert_resized]) |
| output_path = Path(student_image_path).with_name(Path(student_image_path).stem + "_vlm_compare.jpg") |
| cv2.imwrite(str(output_path), combined, [int(cv2.IMWRITE_JPEG_QUALITY), 88]) |
| return output_path |
|
|
|
|
| def _parse_json_object(text: str) -> dict[str, Any]: |
| cleaned = _strip_json_wrappers(text) |
| candidates = [ |
| cleaned, |
| _first_balanced_json_object(cleaned), |
| ] |
| for candidate in candidates: |
| if not candidate: |
| continue |
| parsed = _try_parse_json_object(candidate) |
| if parsed is not None: |
| return parsed |
|
|
| preview = cleaned.replace("\n", " ")[:220] |
| raise ValueError(f"VLM returned non-JSON content: {preview or '<empty>'}") |
|
|
|
|
| def _strip_json_wrappers(text: str) -> str: |
| cleaned = (text or "").strip() |
| if cleaned.startswith("```"): |
| cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned, flags=re.IGNORECASE) |
| cleaned = re.sub(r"\s*```$", "", cleaned) |
| return cleaned.strip() |
|
|
|
|
| def _first_balanced_json_object(text: str) -> str: |
| start = text.find("{") |
| if start < 0: |
| return "" |
|
|
| depth = 0 |
| in_string = False |
| escape = False |
| for index in range(start, len(text)): |
| char = text[index] |
| if escape: |
| escape = False |
| continue |
| if char == "\\": |
| escape = True |
| continue |
| if char == '"': |
| in_string = not in_string |
| continue |
| if in_string: |
| continue |
| if char == "{": |
| depth += 1 |
| elif char == "}": |
| depth -= 1 |
| if depth == 0: |
| return text[start : index + 1] |
| return text[start:] |
|
|
|
|
| def _try_parse_json_object(text: str) -> dict[str, Any] | None: |
| for parser in (json.loads, ast.literal_eval): |
| try: |
| parsed = parser(text) |
| except (json.JSONDecodeError, ValueError, SyntaxError): |
| continue |
| if isinstance(parsed, dict): |
| return parsed |
| return None |
|
|
|
|
| def _validate_rep_visual_feedback( |
| parsed: dict[str, Any], |
| rep_context: dict[str, Any], |
| ) -> dict[str, Any]: |
| diagnosis = str(parsed.get("diagnosis") or "").strip() |
| correction = str(parsed.get("correction") or "").strip() |
| feedback = str(parsed.get("feedback") or "").strip() |
| if len(feedback) < 35 and (diagnosis or correction): |
| feedback = " ".join(part for part in (diagnosis, correction) if part) |
| if not feedback: |
| return _rep_visual_fallback(rep_context, "empty VLM feedback") |
|
|
| return { |
| "source": "vlm", |
| "is_error": bool(parsed.get("is_error", True)), |
| "visual_error_label": str(parsed.get("visual_error_label") or "").strip(), |
| "feedback": feedback[:MAX_FEEDBACK_CHARS], |
| "arrow": None, |
| "error": "", |
| } |
|
|
|
|
| def _rep_visual_fallback(rep_context: dict[str, Any], reason: str = "") -> dict[str, Any]: |
| return { |
| "source": "fallback", |
| "is_error": bool(rep_context.get("rule_errors")), |
| "visual_error_label": "", |
| "feedback": "", |
| "arrow": None, |
| "error": reason, |
| } |
|
|
|
|
| def _needs_vlm_retry(parsed: dict[str, Any]) -> bool: |
| feedback = str(parsed.get("feedback") or parsed.get("diagnosis") or "").strip() |
| return len(feedback) < 20 |
|
|
|
|
| def _is_usable_arrow_point(x: float, y: float) -> bool: |
| |
| return 0.03 <= x <= 0.97 and 0.03 <= y <= 0.97 |
|
|
|
|
| def _safe_arrow_label(label: str) -> str: |
| ascii_label = unicodedata.normalize("NFKD", label).encode("ascii", "ignore").decode("ascii") |
| cleaned = re.sub(r"[^A-Za-z0-9 _-]+", "", ascii_label).strip() |
| return (cleaned or "Can sua")[:18] |
|
|
|
|
| def _fallback_feedback(analysis: dict[str, Any]) -> str: |
| main_errors = analysis.get("main_errors", []) |
| if not main_errors: |
| return ( |
| "Nhận xét: Form tổng thể ổn định, chưa phát hiện lỗi nghiêm trọng.\n" |
| "Lỗi chính: Chưa có lỗi nổi bật.\n" |
| "Cách sửa: Tiếp tục giữ thân người thẳng và kiểm soát nhịp xuống-lên.\n" |
| "Bài tập bổ trợ: Plank 2 hiệp, mỗi hiệp 30 giây." |
| ) |
|
|
| top_error = main_errors[0] |
| return ( |
| f"Nhận xét: Bạn hoàn thành {analysis.get('student_reps', 0)} rep với điểm " |
| f"{analysis.get('overall_score_pct', 0)}%.\n" |
| f"Lỗi chính: {top_error.get('label', 'Form chưa ổn định')} xuất hiện " |
| f"{top_error.get('count', 0)} lần.\n" |
| f"Cách sửa: {top_error.get('guidance', 'Giữ nhịp chậm và kiểm soát thân người tốt hơn.')}\n" |
| "Bài tập bổ trợ: Plank 2 hiệp, mỗi hiệp 30 giây trước khi tập lại." |
| ) |
|
|
|
|
| def _mentions_unknown_error(feedback: str, allowed_labels: set[str]) -> bool: |
| known_error_words = { |
| "võng lưng": "võng lưng", |
| "nhô mông": "nhô mông", |
| "mông quá cao": "nhô mông", |
| "cúi đầu": "cúi đầu", |
| "gập cổ": "cúi đầu", |
| "chưa hạ": "chưa hạ", |
| "chưa xuống": "chưa hạ", |
| "cơ thể chưa giữ thẳng": "cơ thể chưa giữ thẳng", |
| } |
| text = feedback.lower() |
| for phrase, normalized in known_error_words.items(): |
| if phrase in text and not any(normalized in label for label in allowed_labels): |
| return True |
| return False |
|
|