# app/pipeline.py
from __future__ import annotations
import os
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Any
from app.config import Settings
from app.file_utils import ensure_dir, save_json, save_text
from app.llm_client import ApiKeyPool, validate_api_key, mask_key
from app.video_payload import build_video_payload
from app.aggregator import merge_overlapping_results
from app.reporter import build_report_markdown
from app.evidence_builder import build_evidence_summary_with_llm
from app.rag_reporter import build_rag_final_report
from app.retriever import retrieve_chunks, build_retrieve_query_from_evidence
from app.document_utils import build_document_bundle
from app.video_preprocess import preprocess_video
from app.audio_utils import extract_audio_from_video, split_audio_to_chunks
from app.agents.document_agent import DocumentAgent
from app.agents.audio_agent import AudioAgent
from app.agents.look_agent import LookAgent
from app.agents.respond_agent import RespondAgent
from app.agents.point_agent import PointAgent
from app.agents.speak_agent import SpeakAgent
from app.agents.inappropriate_agent import InappropriateAgent
from app.agents.timeline_agent import TimelineAgent
# Progress callback: (percent 0-100, human-readable status message) -> None.
ProgressCb = Callable[[int, str], None]
# Log callback: (log line) -> None.
LogCb = Callable[[str], None]
def normalize_path_arg(p: str | None) -> str | None:
if p is None:
return None
p = p.strip()
while len(p) >= 2 and ((p.startswith('"') and p.endswith('"')) or (p.startswith("'") and p.endswith("'"))):
p = p[1:-1].strip()
return p
class AnalysisPipeline:
    """Orchestrates a full analysis run: video, audio, and document branches
    executed in parallel, followed by event aggregation, report generation,
    and an optional RAG retrieval + final-report stage.
    """

    def __init__(self, settings: Settings, progress_cb: ProgressCb | None = None, log_cb: LogCb | None = None):
        self.settings = settings
        # Fall back to no-op callbacks so callers may omit them.
        self.progress_cb = progress_cb or (lambda p, m: None)
        self.log_cb = log_cb or (lambda m: None)

    def progress(self, value: int, message: str) -> None:
        """Report a progress update to both the progress and log callbacks."""
        self.progress_cb(value, message)
        self.log_cb(message)
def _make_run_output_dir(self) -> str:
ensure_dir(self.settings.output_dir)
run_name = datetime.now().strftime("run_%Y%m%d_%H%M%S")
run_dir = os.path.join(self.settings.output_dir, run_name)
ensure_dir(run_dir)
return run_dir
    def _build_video_agents(self):
        """Instantiate the enabled video-analysis agents for the configured model.

        NOTE(review): several agents are currently disabled (commented out
        below) — presumably to limit cost/latency; confirm before re-enabling.
        """
        settings = self.settings
        return [
            LookAgent(settings.model),
            RespondAgent(settings.model),
            # PointAgent(settings.model),
            # SpeakAgent(settings.model),
            # InappropriateAgent(settings.model),
            TimelineAgent(settings.model),
        ]
def _run_video_agents_serial(self, key_pool: ApiKeyPool, video_payload) -> tuple[list[Any], list[dict[str, str]]]:
agents = self._build_video_agents()
results = []
failed_tasks = []
total = len(agents)
self.log_cb("检测到仅 1 个可用 API Key,视频智能体使用串行模式。")
for idx, agent in enumerate(agents, start=1):
try:
client = key_pool.get_client()
result = agent.run(client=client, video_payload=video_payload)
results.append(result)
self.log_cb(f"[VIDEO] 完成 {agent.agent_name()}(串行 {idx}/{total})")
except Exception as e:
failed_tasks.append({
"agent": agent.agent_name(),
"error": f"{type(e).__name__}: {e}",
})
self.log_cb(f"[WARN] 视频智能体失败: {agent.agent_name()} -> {type(e).__name__}: {e}")
return results, failed_tasks
    def _run_video_agents_parallel(self, key_pool: ApiKeyPool, video_payload) -> tuple[list[Any], list[dict[str, str]]]:
        """Run the video agents concurrently, rotating clients from the key pool.

        Returns (successful results, failure records); a single agent failure
        does not abort the others.
        """
        agents = self._build_video_agents()
        results = []
        failed_tasks = []
        total = len(agents)
        self.log_cb(f"检测到多个可用 API Key,视频智能体使用并发轮转模式(并发上限 {min(self.settings.max_workers, total)})。")

        def run_one(agent):
            # Each task grabs its own client so API keys rotate across agents.
            client = key_pool.get_client()
            return agent.agent_name(), agent.run(client=client, video_payload=video_payload)

        with ThreadPoolExecutor(max_workers=min(self.settings.max_workers, total)) as ex:
            futures = {ex.submit(run_one, agent): agent.agent_name() for agent in agents}
            done_count = 0
            for fut in as_completed(futures):
                agent_name = futures[fut]
                # Counted before result inspection so both success and failure
                # logs show the same x/total ordinal.
                done_count += 1
                try:
                    _, result = fut.result()
                    results.append(result)
                    self.log_cb(f"[VIDEO][{done_count}/{total}] 完成 {agent_name}")
                    # Surface fine-grained progress to the frontend.
                    self.progress(18 + int(done_count / total * 20), f"视频智能体分析中({done_count}/{total})")
                except Exception as e:
                    failed_tasks.append({
                        "agent": agent_name,
                        "error": f"{type(e).__name__}: {e}",
                    })
                    self.log_cb(f"[WARN][VIDEO][{done_count}/{total}] 失败 {agent_name} -> {type(e).__name__}: {e}")
                    self.progress(18 + int(done_count / total * 20), f"视频智能体分析中({done_count}/{total})")
        return results, failed_tasks
def _run_video_agents_auto(self, key_pool: ApiKeyPool, video_payload, valid_keys: list[str]) -> tuple[list[Any], list[dict[str, str]]]:
if len(valid_keys) <= 1:
return self._run_video_agents_serial(key_pool, video_payload)
return self._run_video_agents_parallel(key_pool, video_payload)
    def _run_video_task(
        self,
        key_pool: ApiKeyPool,
        valid_keys: list[str],
        input_video: str | None,
        video_url: str | None,
        run_output_dir: str,
    ) -> tuple[list[dict[str, Any]], list[dict[str, str]], str | None]:
        """Video branch: build the video payload (remote URL, or a local file
        optionally preprocessed first), run the video agents, persist results.

        Returns (raw agent result dicts, failure records, local video path
        used for upload — None in remote-URL mode).
        """
        settings = self.settings
        if settings.video_input_mode == "remote_url":
            self.progress(12, "视频支路:使用远程 URL 构建载荷")
            self.log_cb("[VIDEO] 使用 remote_url 模式构建视频载荷。")
            video_payload = build_video_payload(
                video_input_mode=settings.video_input_mode,
                video_fps=settings.video_fps,
                video_path=None,
                video_url=video_url,
                video_mime_type=settings.video_mime_type,
            )
            local_video_for_upload = None
        else:
            # base64/local mode requires an on-disk video file.
            if not input_video:
                raise ValueError("base64 模式下,必须提供本地视频路径。")
            local_video_for_upload = input_video
            if settings.enable_video_preprocess:
                self.progress(12, "视频支路:开始预处理视频")
                self.log_cb("[VIDEO] 开始本地预处理视频 ...")
                pre = preprocess_video(
                    input_path=input_video,
                    output_dir=settings.preprocessed_video_dir,
                    mode=settings.video_preprocess_mode,
                    remove_audio=settings.video_preprocess_remove_audio,
                )
                local_video_for_upload = pre.output_path
                # Persist preprocessing metadata for debugging/reproducibility.
                save_json(os.path.join(run_output_dir, "video_preprocess.json"), {
                    "input_path": pre.input_path,
                    "output_path": pre.output_path,
                    "used_ffmpeg": pre.used_ffmpeg,
                    "message": pre.message,
                    "mode": pre.mode,
                    "ffmpeg_cmd": pre.ffmpeg_cmd,
                })
                self.progress(16, "视频支路:视频预处理完成")
                self.log_cb(f"[VIDEO] 预处理后视频: {pre.output_path}")
            else:
                self.log_cb("[VIDEO] 跳过视频预处理。")
            self.progress(18, "视频支路:构建视频载荷")
            self.log_cb("[VIDEO] 构建视频载荷 ...")
            video_payload = build_video_payload(
                video_input_mode=settings.video_input_mode,
                video_fps=settings.video_fps,
                video_path=local_video_for_upload,
                video_url=None,
                video_mime_type=settings.video_mime_type,
            )
        self.progress(20, "视频支路:开始执行视频智能体")
        self.log_cb("[VIDEO] 开始执行视频智能体 ...")
        video_results, failed_tasks = self._run_video_agents_auto(
            key_pool=key_pool,
            video_payload=video_payload,
            valid_keys=valid_keys,
        )
        # Results are Pydantic models; dump to plain dicts before saving.
        raw_video_results = [r.model_dump() for r in video_results]
        save_json(os.path.join(run_output_dir, "raw_agent_results.json"), raw_video_results)
        if failed_tasks:
            save_json(os.path.join(run_output_dir, "failed_tasks.json"), failed_tasks)
        return raw_video_results, failed_tasks, local_video_for_upload
    def _run_audio_task(
        self,
        key_pool: ApiKeyPool,
        valid_keys: list[str],
        original_video: str | None,
        run_output_dir: str,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Audio branch: extract audio directly from the original video (so it
        does not wait on video preprocessing), chunk it, then run ASR plus the
        audio agent's semantic analysis.

        Returns (audio agent result dicts, transcripts); ([], []) when the
        audio agent is disabled or there is no local video.
        """
        settings = self.settings
        if not settings.enable_audio_agent or not original_video:
            self.log_cb("[AUDIO] 音频智能体未启用或无原始视频,跳过。")
            return [], []
        self.log_cb("[AUDIO] 开始抽取音频 ...")
        audio_dir = os.path.join(run_output_dir, "audio")
        ensure_dir(audio_dir)
        extracted_audio_path = os.path.join(audio_dir, "full_audio.mp3")
        extract_audio_from_video(
            video_path=original_video,
            output_audio_path=extracted_audio_path,
            bitrate="64k",
        )
        self.log_cb("[AUDIO] 开始切分音频 ...")
        chunks = split_audio_to_chunks(
            audio_path=extracted_audio_path,
            output_dir=os.path.join(audio_dir, "chunks"),
            chunk_seconds=settings.audio_chunk_seconds,
        )
        self.log_cb(f"[AUDIO] 音频切分完成,共 {len(chunks)} 块。开始 ASR 与语义分析 ...")
        audio_agent = AudioAgent(
            model=settings.model,
            asr_model=settings.audio_asr_model,
        )
        # The audio agent handles its own serial/parallel decision internally.
        result, transcripts = audio_agent.run(
            key_pool=key_pool,
            audio_chunks=chunks,
            valid_key_count=len(valid_keys),
            max_workers=settings.max_workers,
            log_cb=self.log_cb,
        )
        audio_agent_results = [result.model_dump()]
        save_json(os.path.join(run_output_dir, "audio_agent_results.json"), audio_agent_results)
        save_json(os.path.join(run_output_dir, "audio_transcripts.json"), transcripts)
        self.log_cb("[AUDIO] 音频智能体任务完成。")
        return audio_agent_results, transcripts
def _run_document_task(
self,
key_pool: ApiKeyPool,
doc_paths: list[str],
run_output_dir: str,
) -> list[dict[str, Any]]:
if not doc_paths:
self.log_cb("[DOC] 未提供文档,跳过文档智能体。")
return []
self.log_cb("[DOC] 开始执行文档智能体任务 ...")
bundle = build_document_bundle(doc_paths)
save_text(os.path.join(run_output_dir, "document_bundle.txt"), bundle)
doc_agent = DocumentAgent(self.settings.model)
client = key_pool.get_client()
doc_result = doc_agent.run(client=client, document_bundle=bundle)
document_results = [doc_result.model_dump()]
save_json(os.path.join(run_output_dir, "document_agent_results.json"), document_results)
self.log_cb("[DOC] 文档智能体任务完成。")
return document_results
    def run(
        self,
        video: str | None = None,
        video_url: str | None = None,
        doc_paths: list[str] | None = None,
    ) -> dict[str, Any]:
        """Execute the full pipeline for one analysis run.

        Fans out the video/audio/document branches in parallel, aggregates
        video events, generates the plain report and evidence summary, and —
        when enabled — runs the RAG retrieval + final-report stage.

        Args:
            video: local video path (quotes/whitespace tolerated).
            video_url: remote video URL (used in remote_url mode).
            doc_paths: optional supporting document paths.

        Returns:
            Summary dict with the run output directory and branch statistics.

        Raises:
            RuntimeError: if all video agents fail, or a RAG stage fails.
        """
        video = normalize_path_arg(video)
        video_url = normalize_path_arg(video_url)
        doc_paths = doc_paths or []
        settings = self.settings
        run_output_dir = self._make_run_output_dir()
        # Pre-compute all RAG artifact paths so every stage writes to the same files.
        rag_status_path = os.path.join(run_output_dir, "rag_status.json")
        rag_raw_path = os.path.join(run_output_dir, "rag_final_raw.json")
        rag_error_path = os.path.join(run_output_dir, "rag_final_error.txt")
        rag_prompt_path = os.path.join(run_output_dir, "rag_prompt.txt")
        retrieved_chunks_path = os.path.join(run_output_dir, "retrieved_chunks.json")
        retrieve_query_path = os.path.join(run_output_dir, "retrieve_query.txt")
        # Seed status files so a crash mid-run still leaves a readable state on disk.
        save_json(rag_status_path, {
            "enabled": settings.enable_rag_final,
            "entered_rag_branch": False,
            "success": False,
            "reason": "not_started",
        })
        save_json(rag_raw_path, {"ok": False, "message": "RAG not executed yet."})
        # API-key pre-validation is currently disabled: all configured keys are
        # assumed usable. The original validation logic is kept below for reference.
        valid_keys = settings.api_keys
        # self.progress(5, "预检 API Keys ...")
        # valid_keys = []
        # invalid_key_logs = []
        #
        #
        # for key in settings.api_keys:
        #     ok, msg = validate_api_key(
        #         api_key=key,
        #         base_url=settings.base_url,
        #         model=settings.model,
        #     )
        #     if ok:
        #         valid_keys.append(key)
        #         self.log_cb(f"[OK] API Key 可用: {mask_key(key)}")
        #     else:
        #         invalid_key_logs.append({"key": mask_key(key), "error": msg})
        #         self.log_cb(f"[BAD] API Key 不可用: {mask_key(key)} -> {msg}")
        #
        # if not valid_keys:
        #     log_path = os.path.join(run_output_dir, "invalid_keys.json")
        #     save_json(log_path, invalid_key_logs)
        #     raise RuntimeError(f"没有任何可用的 API Key。详情见 {log_path}")
        key_pool = ApiKeyPool(api_keys=valid_keys, base_url=settings.base_url)
        self.progress(10, "并发启动视频、音频、文档支路 ...")
        raw_video_results = []
        failed_tasks = []
        audio_agent_results = []
        document_results = []
        local_video_for_upload = None
        # Fan out the three branches; each failure is logged but non-fatal here.
        with ThreadPoolExecutor(max_workers=3) as top_ex:
            future_map = {
                top_ex.submit(self._run_video_task, key_pool, valid_keys, video, video_url, run_output_dir): "video",
                top_ex.submit(self._run_audio_task, key_pool, valid_keys, video, run_output_dir): "audio",
                top_ex.submit(self._run_document_task, key_pool, doc_paths, run_output_dir): "document",
            }
            for fut in as_completed(future_map):
                kind = future_map[fut]
                try:
                    if kind == "video":
                        raw_video_results, failed_tasks, local_video_for_upload = fut.result()
                        self.progress(42, "视频支路完成")
                        if failed_tasks:
                            self.log_cb(f"[VIDEO] 视频智能体部分失败,共 {len(failed_tasks)} 项,详情见 failed_tasks.json")
                    elif kind == "audio":
                        audio_agent_results, _ = fut.result()
                        self.progress(50, "音频支路完成")
                    elif kind == "document":
                        document_results = fut.result()
                        self.progress(58, "文档支路完成")
                except Exception as e:
                    self.log_cb(f"[WARN] 顶层并行任务失败: {kind} -> {type(e).__name__}: {e}")
        # The video branch is mandatory; audio/document branches are best-effort.
        if not raw_video_results:
            raise RuntimeError("所有视频智能体都失败。")
        # Aggregate video events.
        self.progress(64, "聚合视频事件 ...")
        # Re-read the saved raw_agent_results.json rather than reconstructing the
        # merged input from in-memory objects — the aggregator works on plain
        # dicts, so round-tripping through the file is the stable path.
        raw_path = os.path.join(run_output_dir, "raw_agent_results.json")
        with open(raw_path, "r", encoding="utf-8") as f:
            raw_video_results_loaded = json_load_compat(f.read())
        merged_results = merge_from_raw_dicts(raw_video_results_loaded)
        merged_path = os.path.join(run_output_dir, "merged_events.json")
        save_json(merged_path, merged_results)
        self.progress(72, "生成普通报告 ...")
        # The plain markdown report is best-effort; failure does not abort the run.
        try:
            report_client = key_pool.get_client()
            report_md = build_report_markdown(
                client=report_client,
                model=settings.model,
                merged_events=merged_results,
            )
            report_path = os.path.join(run_output_dir, "final_report.md")
            save_text(report_path, report_md)
        except Exception as e:
            self.log_cb(f"[WARN] 普通报告生成失败: {type(e).__name__}: {e}")
        self.progress(78, "生成综合证据摘要 ...")
        evidence_client = key_pool.get_client()
        evidence_summary = build_evidence_summary_with_llm(
            client=evidence_client,
            model=settings.model,
            raw_agent_results=raw_video_results_loaded,
            merged_events=merged_results,
            audio_agent_results=audio_agent_results,
            document_agent_results=document_results,
        )
        evidence_path = os.path.join(run_output_dir, "evidence_summary.json")
        save_json(evidence_path, evidence_summary)
        if not settings.enable_rag_final:
            # RAG disabled: record the final status and return the summary now.
            save_json(rag_status_path, {
                "enabled": False,
                "entered_rag_branch": False,
                "success": False,
                "reason": "rag_disabled_in_env",
            })
            self.progress(100, "分析完成(未启用 RAG)")
            return {
                "output_dir": run_output_dir,
                "rag_enabled": False,
                "document_count": len(doc_paths),
                "audio_agent_enabled": settings.enable_audio_agent,
                "valid_key_count": len(valid_keys),
                "video_agent_mode": "serial" if len(valid_keys) <= 1 else "parallel",
                "audio_agent_mode": "serial" if len(valid_keys) <= 1 else "parallel",
            }
        self.progress(86, "显式检索知识库 ...")
        save_json(rag_status_path, {
            "enabled": True,
            "entered_rag_branch": True,
            "success": False,
            "reason": "retrieving_chunks",
        })
        # Explicit retrieval against the knowledge base; a failure here is fatal
        # for the RAG branch and is recorded to rag_final_error.txt.
        try:
            retrieve_query = build_retrieve_query_from_evidence(evidence_summary)
            save_text(retrieve_query_path, retrieve_query)
            retrieved_chunks = retrieve_chunks(
                access_key_id=settings.alibaba_cloud_access_key_id,
                access_key_secret=settings.alibaba_cloud_access_key_secret,
                workspace_id=settings.bailian_workspace_id,
                index_id=settings.bailian_index_id,
                query=retrieve_query,
                topn=settings.bailian_retrieve_topn,
                enable_rerank=settings.bailian_retrieve_enable_rerank,
                dense_topk=settings.bailian_retrieve_dense_topk,
                sparse_topk=settings.bailian_retrieve_sparse_topk,
                min_score=settings.bailian_retrieve_min_score,
            )
            save_json(retrieved_chunks_path, retrieved_chunks)
        except Exception as e:
            save_text(rag_error_path, f"Retrieve failed: {type(e).__name__}: {e}")
            raise RuntimeError(f"显式检索失败: {type(e).__name__}: {e}")
        self.progress(92, "生成最终 RAG 报告 ...")
        try:
            rag_client = key_pool.get_client()
            rag_result = build_rag_final_report(
                client=rag_client,
                model=settings.model,
                evidence_summary=evidence_summary,
                retrieved_chunks=retrieved_chunks,
            )
            save_json(rag_raw_path, rag_result)
            save_text(rag_prompt_path, str(rag_result.get("messages", "")))
            rag_md_path = os.path.join(run_output_dir, "rag_final_report.md")
            save_text(rag_md_path, rag_result["output_text"])
        except Exception as e:
            save_text(rag_error_path, f"Generate failed: {type(e).__name__}: {e}")
            raise RuntimeError(f"RAG 报告生成失败: {type(e).__name__}: {e}")
        save_json(rag_status_path, {
            "enabled": True,
            "entered_rag_branch": True,
            "success": True,
            "reason": "explicit_retrieve_and_generate_success",
            "retrieved_count": len(retrieved_chunks.get("nodes", [])),
        })
        self.progress(100, "全部分析完成")
        return {
            "output_dir": run_output_dir,
            "rag_enabled": True,
            "retrieved_count": len(retrieved_chunks.get("nodes", [])),
            "document_count": len(doc_paths),
            "audio_agent_enabled": settings.enable_audio_agent,
            "valid_key_count": len(valid_keys),
            "video_agent_mode": "serial" if len(valid_keys) <= 1 else "parallel",
            "audio_agent_mode": "serial" if len(valid_keys) <= 1 else "parallel",
        }
def json_load_compat(text: str) -> list[dict[str, Any]]:
    """Parse a JSON document (expected: a list of result dicts) from *text*."""
    from json import loads

    return loads(text)
def merge_from_raw_dicts(raw_agent_results: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Aggregate raw video-agent findings into merged warning events.

    Findings are grouped by ``warning_type``; within a group, events whose
    time ranges overlap (or start within a 2-second gap of the previous end)
    are fused into one merged event whose confidence is the capped mean of
    the fused findings. Only video-agent results are aggregated here; audio
    and document results are integrated later by the evidence-summary layer.

    Robustness: None/missing numeric fields are coerced to 0, missing tag
    lists to empty, and missing text fields to "" so malformed findings do
    not crash the aggregation (the original crashed on explicit nulls, e.g.
    ``set(None)`` or sorting None start times).

    Args:
        raw_agent_results: dicts loaded from raw_agent_results.json, each
            with "agent_name" and a "findings" list.

    Returns:
        Merged event dicts sorted by (start_sec, warning_type).
    """
    from collections import defaultdict

    def _finalize(acc: dict[str, Any]) -> dict[str, Any]:
        # Collapse the running accumulator into the public merged-event schema.
        return {
            "warning_type": acc["warning_type"],
            "start_sec": acc["start_sec"],
            "end_sec": acc["end_sec"],
            "confidence": min(1.0, acc["confidence_sum"] / acc["count"]),
            "evidences": acc["evidences"],
            "sources": acc["sources"],
            "behavior_tags": sorted(acc["behavior_tags"]),
            "clinical_note": acc["clinical_note"],
        }

    grouped: dict[str, list[tuple[str, dict[str, Any]]]] = defaultdict(list)
    for result in raw_agent_results:
        agent_name = result.get("agent_name", "")
        for finding in result.get("findings") or []:
            warning_type = finding.get("warning_type")
            if warning_type:
                grouped[warning_type].append((agent_name, finding))

    merged: list[dict[str, Any]] = []
    for warning_type, items in grouped.items():
        # Sort by start time so a single forward pass can merge neighbors.
        items.sort(key=lambda x: float(x[1].get("start_sec") or 0))
        current: dict[str, Any] | None = None
        for agent_name, f in items:
            start_sec = float(f.get("start_sec") or 0)
            end_sec = float(f.get("end_sec") or 0)
            confidence = float(f.get("confidence") or 0)
            evidence = f.get("evidence") or ""
            behavior_tags = set(f.get("behavior_tags") or [])
            clinical_note = f.get("clinical_note") or ""
            # Fuse when this finding starts within 2s of the current event's end.
            if current is not None and start_sec <= current["end_sec"] + 2.0:
                current["end_sec"] = max(current["end_sec"], end_sec)
                current["confidence_sum"] += confidence
                current["count"] += 1
                current["evidences"].append(evidence)
                current["sources"].append(agent_name)
                current["behavior_tags"].update(behavior_tags)
                # Keep the most detailed (longest) clinical note.
                if len(clinical_note) > len(current["clinical_note"]):
                    current["clinical_note"] = clinical_note
                continue
            if current is not None:
                merged.append(_finalize(current))
            current = {
                "warning_type": warning_type,
                "start_sec": start_sec,
                "end_sec": end_sec,
                "confidence_sum": confidence,
                "count": 1,
                "evidences": [evidence],
                "sources": [agent_name],
                "behavior_tags": behavior_tags,
                "clinical_note": clinical_note,
            }
        if current is not None:
            merged.append(_finalize(current))
    merged.sort(key=lambda x: (x["start_sec"], x["warning_type"]))
    return merged