yunyixuan commited on
Commit
9a502b8
·
1 Parent(s): 5e85d28

Speed up HF analysis progress reporting

Browse files
RAG/Knowledge_Database/RAGFunc.py CHANGED
@@ -168,6 +168,7 @@ def get_video_ori_keywords(
168
  language='zh',
169
  show=False,
170
  template_keyframes_dir = None,
 
171
  ) -> dict:
172
  """
173
  Multi-stage multimodal assessment for archery posture.
@@ -191,6 +192,13 @@ def get_video_ori_keywords(
191
  return "en"
192
  raise ValueError("language must be 'en' or 'zh'")
193
 
 
 
 
 
 
 
 
194
  def _chat_completion_output_text(resp) -> str:
195
  if resp is None or not getattr(resp, "choices", None):
196
  raise ValueError("Chat completion returned empty response.")
@@ -539,13 +547,22 @@ def get_video_ori_keywords(
539
  image_width=1920,
540
  image_height=1080,
541
  draw_math_feature_points=show,
 
 
 
 
 
 
 
 
542
  )
543
- base_keyframes = extract_keyframes_with_ruptures_poseparts_2d(normalized_data, k=target_k + 3)
544
  key_frame_lists = refine_keyframes_with_absdiff(
545
  video_path=video_path,
546
  keyframe_result=base_keyframes,
547
  k=target_k,
548
  )
 
549
  keyframes_dict = extract_show_keyframes_by_index(video_path, key_frame_lists, show=show)
550
  keyframe_image_items = list(keyframes_dict.get("openai_input_images", []) or [])
551
  if not keyframe_image_items:
@@ -593,6 +610,7 @@ def get_video_ori_keywords(
593
  # + f"- min_x_diff_avg: {metrics['min_x_diff_avg']}"
594
  )
595
  if pipeline == 3:
 
596
  p3_system_prompt = system_prompt + system_metrics_rubric_map[lang]
597
  user_text = ordering_note #+ metrics_text
598
  raw_user_content = [{"type": "input_text", "text": user_text}] + keyframe_image_items
@@ -607,6 +625,7 @@ def get_video_ori_keywords(
607
  return _parse_assessment_payload(_chat_completion_output_text(resp))
608
 
609
  if pipeline == 4:
 
610
  repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
611
  default_template_dir = os.path.join(repo_root, "output_keyframes")
612
  template_dir = template_keyframes_dir if template_keyframes_dir else default_template_dir
@@ -629,6 +648,7 @@ def get_video_ori_keywords(
629
  + template_keyframe_items
630
  )
631
  chat_user_content = _to_chat_content(raw_user_content)
 
632
  resp = client.chat.completions.create(
633
  model=model_name,
634
  messages=[
 
168
  language='zh',
169
  show=False,
170
  template_keyframes_dir = None,
171
+ progress_callback=None,
172
  ) -> dict:
173
  """
174
  Multi-stage multimodal assessment for archery posture.
 
192
  return "en"
193
  raise ValueError("language must be 'en' or 'zh'")
194
 
195
+ def _notify_progress(stage: str, message: str) -> None:
196
+ if callable(progress_callback):
197
+ try:
198
+ progress_callback(stage, message)
199
+ except Exception:
200
+ pass
201
+
202
  def _chat_completion_output_text(resp) -> str:
203
  if resp is None or not getattr(resp, "choices", None):
204
  raise ValueError("Chat completion returned empty response.")
 
547
  image_width=1920,
548
  image_height=1080,
549
  draw_math_feature_points=show,
550
+ progress_callback=progress_callback,
551
+ )
552
+ _notify_progress("selecting_keyframes", "正在筛选关键帧")
553
+ base_keyframes = extract_keyframes_with_ruptures_poseparts_2d(
554
+ normalized_data,
555
+ k=target_k + 3,
556
+ print_all_frame_scores=show,
557
+ print_selection_debug=show,
558
  )
559
+ _notify_progress("refining_keyframes", "正在细化关键帧")
560
  key_frame_lists = refine_keyframes_with_absdiff(
561
  video_path=video_path,
562
  keyframe_result=base_keyframes,
563
  k=target_k,
564
  )
565
+ _notify_progress("rendering_keyframes", "正在整理关键帧输入")
566
  keyframes_dict = extract_show_keyframes_by_index(video_path, key_frame_lists, show=show)
567
  keyframe_image_items = list(keyframes_dict.get("openai_input_images", []) or [])
568
  if not keyframe_image_items:
 
610
  # + f"- min_x_diff_avg: {metrics['min_x_diff_avg']}"
611
  )
612
  if pipeline == 3:
613
+ _notify_progress("generating_assessment", "正在生成动作评估")
614
  p3_system_prompt = system_prompt + system_metrics_rubric_map[lang]
615
  user_text = ordering_note #+ metrics_text
616
  raw_user_content = [{"type": "input_text", "text": user_text}] + keyframe_image_items
 
625
  return _parse_assessment_payload(_chat_completion_output_text(resp))
626
 
627
  if pipeline == 4:
628
+ _notify_progress("loading_template_keyframes", "正在加载模板关键帧")
629
  repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
630
  default_template_dir = os.path.join(repo_root, "output_keyframes")
631
  template_dir = template_keyframes_dir if template_keyframes_dir else default_template_dir
 
648
  + template_keyframe_items
649
  )
650
  chat_user_content = _to_chat_content(raw_user_content)
651
+ _notify_progress("generating_assessment", "正在生成动作评估")
652
  resp = client.chat.completions.create(
653
  model=model_name,
654
  messages=[
RAG/tokenize_search.py CHANGED
@@ -2,7 +2,7 @@ from RAG.Knowledge_Database.RAGFunc import *
2
  from RAG.Knowledge_Database.AIdbconfig import session, session_en
3
  from RAG.Knowledge_Database.AI_dbmanager import KnowledgeDB
4
 
5
- def Tokenize_SearchKeyword(video_path, pipeline=1, subpipeline=3, language='zh',show=False):
6
  """
7
  Extract keywords from video and search knowledge database
8
 
@@ -15,7 +15,13 @@ def Tokenize_SearchKeyword(video_path, pipeline=1, subpipeline=3, language='zh',
15
  List of keywords from knowledge database
16
  """
17
  if pipeline == 1:
18
- answer_content = get_video_ori_keywords(video_path, pipeline=subpipeline, language=language,show=show)
 
 
 
 
 
 
19
  total_score = answer_content["total_score"]
20
  head_score = answer_content["head_score"]
21
  hand_score = answer_content["hand_score"]
@@ -43,6 +49,11 @@ def Tokenize_SearchKeyword(video_path, pipeline=1, subpipeline=3, language='zh',
43
  db = KnowledgeDB(session=session_en if language == 'en' else session)
44
  top_k = 34 if language == 'en' else 34
45
  #TODO 看embedding模型能不能换成qwen-embedding?
 
 
 
 
 
46
  return score_dict, comment, db.from_video_search(query_vec=query_embeddings[0], model_name='ali-text-embedding-v3', top_k=top_k)
47
 
48
  elif pipeline == 2:
@@ -51,4 +62,4 @@ def Tokenize_SearchKeyword(video_path, pipeline=1, subpipeline=3, language='zh',
51
  if language == 'en':
52
  return db.from_video_search(query_vec=video_token, model_name='languagebind', top_k=34)
53
  else:
54
- return db.from_video_search(query_vec=video_token, model_name='languagebind', top_k=17)
 
2
  from RAG.Knowledge_Database.AIdbconfig import session, session_en
3
  from RAG.Knowledge_Database.AI_dbmanager import KnowledgeDB
4
 
5
+ def Tokenize_SearchKeyword(video_path, pipeline=1, subpipeline=3, language='zh', show=False, progress_callback=None):
6
  """
7
  Extract keywords from video and search knowledge database
8
 
 
15
  List of keywords from knowledge database
16
  """
17
  if pipeline == 1:
18
+ answer_content = get_video_ori_keywords(
19
+ video_path,
20
+ pipeline=subpipeline,
21
+ language=language,
22
+ show=show,
23
+ progress_callback=progress_callback,
24
+ )
25
  total_score = answer_content["total_score"]
26
  head_score = answer_content["head_score"]
27
  hand_score = answer_content["hand_score"]
 
49
  db = KnowledgeDB(session=session_en if language == 'en' else session)
50
  top_k = 34 if language == 'en' else 34
51
  #TODO 看embedding模型能不能换成qwen-embedding?
52
+ if callable(progress_callback):
53
+ try:
54
+ progress_callback("retrieving_knowledge", "正在检索技术知识库")
55
+ except Exception:
56
+ pass
57
  return score_dict, comment, db.from_video_search(query_vec=query_embeddings[0], model_name='ali-text-embedding-v3', top_k=top_k)
58
 
59
  elif pipeline == 2:
 
62
  if language == 'en':
63
  return db.from_video_search(query_vec=video_token, model_name='languagebind', top_k=34)
64
  else:
65
+ return db.from_video_search(query_vec=video_token, model_name='languagebind', top_k=17)
RTMPose/Bone_Feature_Extract.py CHANGED
@@ -5,9 +5,13 @@ import ruptures as rpt
5
  import matplotlib.pyplot as plt
6
  import os
7
  import base64
 
8
 
9
  from Tools.Exe_dataset.model_config import model_configs
10
 
 
 
 
11
  COCO133_KPT_IDX = {
12
  # Body 17
13
  "left_shoulder": 5,
@@ -71,6 +75,7 @@ def Keypoint_Extract(
71
  track_conf_kpt_thr=0.2,
72
  draw_tracked_only=True,
73
  draw_math_feature_points=False,
 
74
  ):
75
  from RTMPose.rtmlib import PoseTracker, Wholebody3d, draw_skeleton
76
  """
@@ -209,19 +214,39 @@ def Keypoint_Extract(
209
  )
210
  return frame_bgr
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  # Keep this False by default; turn on only when diagnosing index mismatches.
213
  _maybe_print_coco133_index_self_check(enabled=False)
214
 
215
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
216
  backend = 'onnxruntime'
217
  cap = cv2.VideoCapture(path_to_video)
 
218
 
219
- wholebody3d = PoseTracker(
220
- Wholebody3d,
221
- det_frequency=7,
222
- tracking=False,
223
- backend=backend,
224
- device='cuda' if device.type == 'cuda' else 'cpu')
225
 
226
  frame_idx = -1
227
  whole_skeleton_data = []
 
5
  import matplotlib.pyplot as plt
6
  import os
7
  import base64
8
+ import threading
9
 
10
  from Tools.Exe_dataset.model_config import model_configs
11
 
12
+ _POSE_TRACKER_CACHE = {}
13
+ _POSE_TRACKER_LOCK = threading.Lock()
14
+
15
  COCO133_KPT_IDX = {
16
  # Body 17
17
  "left_shoulder": 5,
 
75
  track_conf_kpt_thr=0.2,
76
  draw_tracked_only=True,
77
  draw_math_feature_points=False,
78
+ progress_callback=None,
79
  ):
80
  from RTMPose.rtmlib import PoseTracker, Wholebody3d, draw_skeleton
81
  """
 
214
  )
215
  return frame_bgr
216
 
217
+ def _notify_progress(stage, message):
218
+ if callable(progress_callback):
219
+ try:
220
+ progress_callback(stage, message)
221
+ except Exception:
222
+ pass
223
+
224
+ def _get_cached_pose_tracker(backend_name, device_name):
225
+ cache_key = (backend_name, device_name)
226
+ with _POSE_TRACKER_LOCK:
227
+ tracker = _POSE_TRACKER_CACHE.get(cache_key)
228
+ if tracker is None:
229
+ tracker = PoseTracker(
230
+ Wholebody3d,
231
+ det_frequency=7,
232
+ tracking=False,
233
+ backend=backend_name,
234
+ device=device_name,
235
+ )
236
+ _POSE_TRACKER_CACHE[cache_key] = tracker
237
+ return tracker
238
+
239
  # Keep this False by default; turn on only when diagnosing index mismatches.
240
  _maybe_print_coco133_index_self_check(enabled=False)
241
 
242
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
243
  backend = 'onnxruntime'
244
  cap = cv2.VideoCapture(path_to_video)
245
+ device_name = 'cuda' if device.type == 'cuda' else 'cpu'
246
 
247
+ _notify_progress("loading_pose_model", "正在加载姿态模型")
248
+ wholebody3d = _get_cached_pose_tracker(backend, device_name)
249
+ _notify_progress("extracting_keypoints", "正在提取人体关键点")
 
 
 
250
 
251
  frame_idx = -1
252
  whole_skeleton_data = []
webapp/backend/app.py CHANGED
@@ -67,16 +67,26 @@ def _score_payload(scores: dict[str, Any]) -> dict[str, Any]:
67
  }
68
 
69
 
70
- def analyze_video(video_path: str, language: str, subpipeline: int) -> dict[str, Any]:
 
 
 
 
 
71
  _ensure_api_key()
 
 
72
  scores, raw_keywords, retrieved_result = Tokenize_SearchKeyword(
73
  video_path=video_path,
74
  pipeline=1,
75
  subpipeline=subpipeline,
76
  language=language,
77
  show=False,
 
78
  )
79
  keywords = _coerce_keywords(raw_keywords)
 
 
80
  assessment_text = get_response(
81
  keywords=keywords,
82
  score_dict=scores,
@@ -116,6 +126,8 @@ class JobRecord:
116
  filename: str
117
  language: str
118
  status: str = "queued"
 
 
119
  created_at: float = field(default_factory=time.time)
120
  updated_at: float = field(default_factory=time.time)
121
  error: str | None = None
@@ -152,12 +164,24 @@ class JobStore:
152
  with self._lock:
153
  job = self._jobs[job_id]
154
  job.status = "running"
 
 
 
 
 
 
 
 
 
 
155
  job.updated_at = time.time()
156
 
157
  def set_completed(self, job_id: str, result: dict[str, Any]) -> None:
158
  with self._lock:
159
  job = self._jobs[job_id]
160
  job.status = "completed"
 
 
161
  job.result = result
162
  job.updated_at = time.time()
163
 
@@ -165,6 +189,8 @@ class JobStore:
165
  with self._lock:
166
  job = self._jobs[job_id]
167
  job.status = "failed"
 
 
168
  job.error = error
169
  job.updated_at = time.time()
170
 
@@ -194,6 +220,8 @@ def _public_job_payload(record: JobRecord) -> dict[str, Any]:
194
  "filename": record.filename,
195
  "language": record.language,
196
  "status": record.status,
 
 
197
  "created_at": record.created_at,
198
  "updated_at": record.updated_at,
199
  "error": record.error,
@@ -213,7 +241,15 @@ def _public_job_payload(record: JobRecord) -> dict[str, Any]:
213
  def _run_job(job_id: str, temp_video_path: str, language: str, subpipeline: int) -> None:
214
  jobs.set_running(job_id)
215
  try:
216
- result = analyze_video(temp_video_path, language=language, subpipeline=subpipeline)
 
 
 
 
 
 
 
 
217
  jobs.set_completed(job_id, result)
218
  except Exception as exc:
219
  jobs.set_failed(job_id, str(exc))
 
67
  }
68
 
69
 
70
+ def analyze_video(
71
+ video_path: str,
72
+ language: str,
73
+ subpipeline: int,
74
+ progress_callback=None,
75
+ ) -> dict[str, Any]:
76
  _ensure_api_key()
77
+ if callable(progress_callback):
78
+ progress_callback("starting", "正在准备分析任务")
79
  scores, raw_keywords, retrieved_result = Tokenize_SearchKeyword(
80
  video_path=video_path,
81
  pipeline=1,
82
  subpipeline=subpipeline,
83
  language=language,
84
  show=False,
85
+ progress_callback=progress_callback,
86
  )
87
  keywords = _coerce_keywords(raw_keywords)
88
+ if callable(progress_callback):
89
+ progress_callback("writing_assessment", "正在生成评估结论")
90
  assessment_text = get_response(
91
  keywords=keywords,
92
  score_dict=scores,
 
126
  filename: str
127
  language: str
128
  status: str = "queued"
129
+ stage: str = "queued"
130
+ status_message: str = "任务已创建,等待处理"
131
  created_at: float = field(default_factory=time.time)
132
  updated_at: float = field(default_factory=time.time)
133
  error: str | None = None
 
164
  with self._lock:
165
  job = self._jobs[job_id]
166
  job.status = "running"
167
+ job.stage = "starting"
168
+ job.status_message = "任务开始执行"
169
+ job.updated_at = time.time()
170
+
171
+ def set_progress(self, job_id: str, stage: str, status_message: str) -> None:
172
+ with self._lock:
173
+ job = self._jobs[job_id]
174
+ job.status = "running"
175
+ job.stage = stage
176
+ job.status_message = status_message
177
  job.updated_at = time.time()
178
 
179
  def set_completed(self, job_id: str, result: dict[str, Any]) -> None:
180
  with self._lock:
181
  job = self._jobs[job_id]
182
  job.status = "completed"
183
+ job.stage = "completed"
184
+ job.status_message = "分析完成,可以继续提问"
185
  job.result = result
186
  job.updated_at = time.time()
187
 
 
189
  with self._lock:
190
  job = self._jobs[job_id]
191
  job.status = "failed"
192
+ job.stage = "failed"
193
+ job.status_message = error
194
  job.error = error
195
  job.updated_at = time.time()
196
 
 
220
  "filename": record.filename,
221
  "language": record.language,
222
  "status": record.status,
223
+ "stage": record.stage,
224
+ "status_message": record.status_message,
225
  "created_at": record.created_at,
226
  "updated_at": record.updated_at,
227
  "error": record.error,
 
241
  def _run_job(job_id: str, temp_video_path: str, language: str, subpipeline: int) -> None:
242
  jobs.set_running(job_id)
243
  try:
244
+ def progress_callback(stage: str, message: str) -> None:
245
+ jobs.set_progress(job_id, stage, message)
246
+
247
+ result = analyze_video(
248
+ temp_video_path,
249
+ language=language,
250
+ subpipeline=subpipeline,
251
+ progress_callback=progress_callback,
252
+ )
253
  jobs.set_completed(job_id, result)
254
  except Exception as exc:
255
  jobs.set_failed(job_id, str(exc))
webapp/static/app.js CHANGED
@@ -29,9 +29,9 @@ function resetResults() {
29
  activeJobId = null;
30
  jobIdEl.textContent = "";
31
  scoresEl.innerHTML = "";
32
- keywordsEl.textContent = "分析完成后显示";
33
  keywordsEl.className = "chips empty";
34
- assessmentEl.textContent = "分析完成后显示";
35
  assessmentEl.className = "answer-box empty";
36
  chatLog.innerHTML = '<div class="message message-system">分析完成后,可以继续和 SEMA 交互。</div>';
37
  questionInput.disabled = true;
@@ -57,7 +57,7 @@ function renderScores(scores) {
57
 
58
  function renderKeywords(keywords) {
59
  if (!keywords || keywords.length === 0) {
60
- keywordsEl.textContent = "无关键词";
61
  keywordsEl.className = "chips empty";
62
  return;
63
  }
@@ -96,7 +96,10 @@ async function pollJob(jobId) {
96
  }
97
 
98
  if (job.status === "queued" || job.status === "running") {
99
- setStatus("running", job.status === "queued" ? "任务已创建,等待处理。" : "正在分析视频,请稍候。");
 
 
 
100
  pollHandle = setTimeout(() => pollJob(jobId), 2500);
101
  return;
102
  }
@@ -139,7 +142,7 @@ analyzeForm.addEventListener("submit", async (event) => {
139
  if (!response.ok) {
140
  throw new Error(payload.detail || "任务创建失败。");
141
  }
142
- setStatus("running", "上传完成,任务已进入队列。");
143
  pollJob(payload.job_id);
144
  } catch (error) {
145
  setStatus("error", error.message || "上传失败。");
 
29
  activeJobId = null;
30
  jobIdEl.textContent = "";
31
  scoresEl.innerHTML = "";
32
+ keywordsEl.textContent = "分析完成后显示";
33
  keywordsEl.className = "chips empty";
34
+ assessmentEl.textContent = "分析完成后显示";
35
  assessmentEl.className = "answer-box empty";
36
  chatLog.innerHTML = '<div class="message message-system">分析完成后,可以继续和 SEMA 交互。</div>';
37
  questionInput.disabled = true;
 
57
 
58
  function renderKeywords(keywords) {
59
  if (!keywords || keywords.length === 0) {
60
+ keywordsEl.textContent = "无关键词";
61
  keywordsEl.className = "chips empty";
62
  return;
63
  }
 
96
  }
97
 
98
  if (job.status === "queued" || job.status === "running") {
99
+ const fallbackText = job.status === "queued"
100
+ ? "任务已创建,等待处理。"
101
+ : "正在分析视频,请稍候。";
102
+ setStatus("running", job.status_message || fallbackText);
103
  pollHandle = setTimeout(() => pollJob(jobId), 2500);
104
  return;
105
  }
 
142
  if (!response.ok) {
143
  throw new Error(payload.detail || "任务创建失败。");
144
  }
145
+ setStatus("running", payload.status_message || "上传完成,任务已进入队列。");
146
  pollJob(payload.job_id);
147
  } catch (error) {
148
  setStatus("error", error.message || "上传失败。");