| import hashlib |
| import json |
| import logging |
| import os |
| import queue as _queue |
| import re |
| import subprocess |
| import tempfile |
| import threading |
| import time |
| import warnings |
| import shutil |
| import atexit |
|
|
| import gradio as gr |
| from dotenv import load_dotenv |
| from openai import OpenAI |
|
|
| |
# Keep Gradio's temporary files in a directory next to this script so they can
# be located and wiped deterministically at exit (see cleanup_temp_dir).
_script_dir = os.path.dirname(os.path.abspath(__file__))
_temp_dir = os.path.join(_script_dir, ".gradio_temp")
os.makedirs(_temp_dir, exist_ok=True)
os.environ["GRADIO_TEMP_DIR"] = _temp_dir
# NOTE(review): monkeypatching tempfile.gettempdir redirects *every* library in
# this process (not just Gradio) into .gradio_temp — confirm that is intended.
tempfile.gettempdir = lambda: _temp_dir
|
|
# Load .env configuration and quiet noisy third-party output (httpx logs every
# request at INFO by default).
load_dotenv()
warnings.filterwarnings("ignore")
logging.getLogger("httpx").setLevel(logging.WARNING)
|
|
| from VideoAgent import QueryParam, VideoRAG |
|
|
| |
# Custom stylesheet injected into gr.Blocks(css=...). Class names referenced
# below are attached to components via elem_classes in the UI section.
custom_css = """
.gradio-container {
  background: radial-gradient(1000px 320px at 50% -80px, #dbeafe 0%, #f6f8fc 45%, #f6f8fc 100%);
  color: #0f172a;
}
.app-title {
  text-align: center;
  margin: 6px 0 14px 0;
}
.app-title h1 {
  margin: 0;
  font-size: 25px;
  font-weight: 700;
  color: #1e293b;
}
.card-style {
  border-radius: 10px !important;
  border: 1px solid #dde5f1 !important;
  padding: 12px !important;
  background: #ffffff !important;
  box-shadow: 0 3px 12px rgba(30, 41, 59, 0.05);
}
.section-label {
  font-weight: 600;
  color: #1e293b;
  margin-bottom: 8px;
  display: flex;
  align-items: center;
  font-size: 14px;
}
.gradio-container .gr-button-primary {
  background: linear-gradient(135deg, #4f46e5 0%, #2563eb 100%) !important;
  border: none !important;
}
.gradio-container .gr-button-secondary {
  border-color: #cbd5e1 !important;
}
.console-font textarea {
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Courier New", monospace !important;
  font-size: 12px !important;
  background: #0f172a !important;
  color: #e2e8f0 !important;
}
.video-box {
  border-radius: 10px !important;
  overflow: hidden !important;
  border: 1px solid #d7deeb;
}
.helper-text {
  font-size: 12px;
  color: #64748b;
  margin: 0;
}
.search-toolbar {
  padding: 10px 12px !important;
  margin-bottom: 4px;
}
.search-query textarea {
  font-size: 14px !important;
  line-height: 1.5 !important;
  min-height: 68px !important;
}
.search-actions {
  margin-top: 6px;
  justify-content: flex-end;
  gap: 8px;
}
.search-actions .gr-button {
  min-height: 40px !important;
  font-size: 13px !important;
  border-radius: 8px !important;
  min-width: 118px;
}
.search-panel {
  margin-top: 4px;
}
.result-box {
  border: 1px solid #d7deeb;
  border-radius: 10px;
  background: linear-gradient(180deg, #ffffff 0%, #fbfdff 100%);
  padding: 12px 14px;
  min-height: 360px;
  max-height: 360px;
  overflow: auto;
  line-height: 1.5;
  box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.5);
}
.result-box h1, .result-box h2, .result-box h3 {
  margin-top: 0.35em;
  margin-bottom: 0.35em;
}
.result-box p {
  margin: 0.45em 0;
}
.clip-gallery {
  border: 1px solid #d7deeb;
  border-radius: 10px;
  padding: 6px;
  background: #ffffff;
}
.clip-gallery img, .clip-gallery video {
  border-radius: 8px !important;
}
.settings-group {
  margin-bottom: 20px;
}
.settings-section-title {
  font-size: 16px !important;
  font-weight: 600 !important;
  color: #334155 !important;
  margin-bottom: 12px !important;
  padding-bottom: 8px !important;
  border-bottom: 1px solid #e2e8f0;
}
.config-card {
  background: #f8fafc !important;
  border-radius: 8px !important;
  padding: 12px !important;
  border: 1px solid #e2e8f0 !important;
  margin-bottom: 12px;
}
.param-row {
  display: flex !important;
  gap: 15px !important;
  margin-bottom: 12px !important;
}
.param-col {
  flex: 1 !important;
  display: flex !important;
  flex-direction: column !important;
}
.param-label {
  font-size: 13px !important;
  font-weight: 500 !important;
  color: #475569 !important;
  margin-bottom: 4px !important;
}
.param-info {
  font-size: 11px !important;
  color: #94a3b8 !important;
  margin-top: 2px !important;
}
.apply-btn-container {
  text-align: center;
  margin-top: 20px;
}
.gradio-accordion .label-wrap {
  padding: 8px 12px !important;
}
"""
|
|
| |
# Lazily-built VideoRAG singleton and the lock guarding its (re)construction;
# both are managed exclusively through _get_rag / apply_system_settings.
_videorag: VideoRAG | None = None
_rag_lock = threading.Lock()
|
|
| |
def cleanup_temp_dir():
    """Empty the Gradio temp directory; videos under working_dir are untouched."""
    try:
        if not os.path.exists(_temp_dir):
            return
        for entry in os.listdir(_temp_dir):
            entry_path = os.path.join(_temp_dir, entry)
            if os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
            elif os.path.isfile(entry_path):
                os.remove(entry_path)
        print(f"已清理Gradio临时目录: {_temp_dir}")
    except Exception as e:
        # Best-effort cleanup: report and carry on, never block interpreter exit.
        print(f"清理临时目录时出错: {e}")
|
|
| |
| atexit.register(cleanup_temp_dir) |
|
|
# VideoRAG constructor keyword -> environment variable used to persist it in
# .env (see _load_rag_runtime_settings / apply_system_settings).
_RAG_ENV_MAP = {
    "video_segment_length": "VIDEORAG_VIDEO_SEGMENT_LENGTH",
    "rough_num_frames_per_segment": "VIDEORAG_ROUGH_NUM_FRAMES_PER_SEGMENT",
    "retrieval_topk_chunks": "VIDEORAG_RETRIEVAL_TOPK_CHUNKS",
    "query_better_than_threshold": "VIDEORAG_QUERY_BETTER_THAN_THRESHOLD",
    "chunk_token_size": "VIDEORAG_CHUNK_TOKEN_SIZE",
    "segment_retrieval_top_k": "VIDEORAG_SEGMENT_RETRIEVAL_TOP_K",
}
|
|
|
|
| def _read_int_env(key: str, default: int) -> int: |
| try: |
| return int(os.getenv(key, str(default)).strip()) |
| except Exception: |
| return default |
|
|
|
|
| def _read_float_env(key: str, default: float) -> float: |
| try: |
| return float(os.getenv(key, str(default)).strip()) |
| except Exception: |
| return default |
|
|
|
|
def _load_rag_runtime_settings() -> dict:
    """Assemble the VideoRAG keyword arguments, honoring env-var overrides."""
    defaults: dict[str, int | float] = {
        "video_segment_length": 20,
        "rough_num_frames_per_segment": 10,
        "retrieval_topk_chunks": 2,
        "query_better_than_threshold": 0.2,
        "chunk_token_size": 1000,
        "segment_retrieval_top_k": 3,
    }
    settings = {}
    for name, default in defaults.items():
        # Only the threshold is a float; everything else is an integer count/size.
        reader = _read_float_env if isinstance(default, float) else _read_int_env
        settings[name] = reader(_RAG_ENV_MAP[name], default)
    return settings
|
|
|
|
| _rag_runtime_settings = _load_rag_runtime_settings() |
|
|
|
|
def _get_rag(working_dir: str) -> VideoRAG:
    """Return the shared VideoRAG instance, rebuilding it when the working
    directory or any runtime setting has drifted from the cached instance."""
    global _videorag
    with _rag_lock:
        stale = _videorag is None or _videorag.working_dir != working_dir
        if not stale:
            stale = any(
                getattr(_videorag, name, None) != value
                for name, value in _rag_runtime_settings.items()
            )
        if stale:
            _videorag = VideoRAG(working_dir=working_dir, **_rag_runtime_settings)
        return _videorag
|
|
|
|
| def _read_indexed_videos(working_dir: str) -> list[str]: |
| kv_path = os.path.join(working_dir, "kv_store_video_path.json") |
| if not os.path.exists(kv_path): |
| return [] |
| try: |
| with open(kv_path, "r", encoding="utf-8") as f: |
| data = json.load(f) |
| return list(data.keys()) |
| except Exception: |
| return [] |
|
|
|
|
| def _fmt_video_list(videos: list[str]) -> str: |
| if not videos: |
| return "📦 暂无已索引视频" |
| return "\n".join(f"• {v}" for v in sorted(videos)) |
|
|
|
|
| def _get_path_from_file(file_obj): |
| if file_obj is None: |
| return None |
| if isinstance(file_obj, str): |
| return file_obj |
| if isinstance(file_obj, os.PathLike): |
| return os.fspath(file_obj) |
| if isinstance(file_obj, dict): |
| for k in ("path", "name"): |
| v = file_obj.get(k) |
| if isinstance(v, str) and v.strip(): |
| return v |
| return None |
| for attr in ("path", "name"): |
| v = getattr(file_obj, attr, None) |
| if isinstance(v, str) and v.strip(): |
| return v |
| return None |
|
|
|
|
| def _parse_clock_text(clock_text: str) -> float: |
| t = clock_text.strip() |
| parts = t.split(":") |
| if len(parts) == 2: |
| mm, ss = parts |
| return int(mm) * 60 + int(ss) |
| if len(parts) == 3: |
| hh, mm, ss = parts |
| return int(hh) * 3600 + int(mm) * 60 + int(ss) |
| raise ValueError(f"无效时间格式: {clock_text}") |
|
|
|
|
# Matches a "M:SS", "MM:SS", or "H:MM:SS"-style time token.
_TIME_TOKEN_RE = r"[0-9]{1,2}:[0-9]{1,2}(?::[0-9]{1,2})?"
# One reference line: optional bullet / "[n]" index / "1." numbering /
# "Reference:"-style prefix, then "<video name>, <start>, <end>" — commas and
# colons may be ASCII or full-width.
_REF_LINE_RE = re.compile(
    rf"^\s*(?:[-*•]\s*)?(?:\[(?P<idx>\d+)\]\s*)?(?:\d+[.)、]\s*)?"
    rf"(?:(?:\*\*)?(?:reference|参考)(?:\*\*)?\s*[::]\s*)?"
    rf"(?P<video>[^,,]+?)\s*[,,]\s*(?P<start>{_TIME_TOKEN_RE})\s*[,,]\s*(?P<end>{_TIME_TOKEN_RE})\s*$",
    flags=re.IGNORECASE,
)
|
|
|
|
| def _is_reference_header(line: str) -> bool: |
| s = line.strip() |
| s = re.sub(r"\*\*", "", s) |
| s = s.replace(":", ":") |
| return bool( |
| re.match(r"^#{1,6}\s*(reference|参考)\s*:?\s*$", s, flags=re.IGNORECASE) |
| or re.match(r"^(reference|参考)\s*:?\s*$", s, flags=re.IGNORECASE) |
| ) |
|
|
|
|
def _parse_reference_line(line: str):
    """Parse one candidate reference line into its components.

    Tolerates the many loose formats the LLM may emit, e.g.
    "- [1] **Reference**: my_video, 00:10, 00:25" with ASCII or full-width
    punctuation.  Returns a dict with ref_id (int or None), video_name,
    start/end (seconds, float) and the original time strings, or None when
    the line does not look like a reference.
    """
    # Normalize full-width commas so a single splitting path handles both.
    normalized = line.strip().replace(",", ",")
    if not normalized:
        return None

    # Keep only the text after a "Reference:"/"参考:" label, if one is present.
    tail = re.search(r"(?:reference|参考)\s*[::]\s*(.+)$", normalized, flags=re.IGNORECASE)
    if tail:
        normalized = tail.group(1).strip()

    # A leading "[n]" marker becomes the explicit reference id.
    idx_text = None
    m_idx = re.match(r"^\s*\[(\d+)\]\s*(.*)$", normalized)
    if m_idx:
        idx_text = m_idx.group(1)
        normalized = m_idx.group(2).strip()

    # Strip bullets, "1." / "1)" numbering and any residual label prefix.
    normalized = re.sub(r"^\s*[-*•]\s*", "", normalized)
    normalized = re.sub(r"^(?:\d+[.)、])\s*", "", normalized)
    normalized = re.sub(r"^(?:\*\*)?(?:reference|参考)(?:\*\*)?\s*[::]\s*", "", normalized, flags=re.IGNORECASE)

    # Unwrap "[video, start, end]"-style bracketing.
    if normalized.startswith("[") and normalized.endswith("]"):
        normalized = normalized[1:-1].strip()

    m = _REF_LINE_RE.match(normalized)
    if m:
        idx_text = idx_text or m.group("idx")
        video_name = m.group("video").strip().strip("`").strip('"').strip("'").strip("[]")
        start_text = m.group("start")
        end_text = m.group("end")
    else:
        # Fallback: naive comma split — everything before the last two fields
        # is treated as the (possibly comma-containing) video name.
        parts = [p.strip() for p in normalized.split(",") if p.strip()]
        if len(parts) < 3:
            return None
        video_name = ",".join(parts[:-2]).strip().strip("`").strip('"').strip("'").strip("[]")
        start_text, end_text = parts[-2], parts[-1]

    if not video_name:
        return None
    try:
        start = _parse_clock_text(start_text)
        end = _parse_clock_text(end_text)
    except Exception:
        # Unparseable timestamps -> not a reference line.
        return None
    # Guarantee a non-empty, playable range even for degenerate inputs.
    if end <= start:
        end = start + 1

    return {
        "ref_id": int(idx_text) if idx_text else None,
        "video_name": video_name,
        "start": float(start),
        "end": float(end),
        "start_text": start_text,
        "end_text": end_text,
    }
|
|
|
|
| def _dedup_and_fill_ref_id(items: list[dict]): |
| uniq = [] |
| seen = set() |
| for it in items: |
| key = (it["video_name"].lower(), int(it["start"]), int(it["end"])) |
| if key in seen: |
| continue |
| seen.add(key) |
| uniq.append(it) |
|
|
| for i, it in enumerate(uniq, start=1): |
| if it["ref_id"] is None: |
| it["ref_id"] = i |
| return sorted(uniq, key=lambda x: x["ref_id"]) |
|
|
|
|
def _extract_reference_items(answer: str):
    """Collect reference clips from *answer*.

    Lines under an explicit "Reference"/"参考" heading take priority; when no
    such section yields anything, every parseable line in the answer is used.
    """
    section_hits: list[dict] = []
    anywhere_hits: list[dict] = []
    inside_section = False

    for raw_line in answer.splitlines():
        stripped = raw_line.strip()
        if _is_reference_header(stripped):
            inside_section = True
            continue

        # Any other markdown heading terminates the reference section.
        if inside_section and re.match(r"^#{1,6}\s+\S+", stripped) and not _is_reference_header(stripped):
            inside_section = False

        item = _parse_reference_line(stripped)
        if item is None:
            continue
        anywhere_hits.append(item)
        if inside_section:
            section_hits.append(item)

    chosen = section_hits if section_hits else anywhere_hits
    return _dedup_and_fill_ref_id(chosen)
|
|
|
|
def _resolve_video_path(rag: VideoRAG, referenced_name: str):
    """Map a (possibly mangled) referenced video name to (db_key, path).

    Match order: exact normalized key, exact raw key, case-insensitive key,
    then substring containment in either direction.  Returns (None, None)
    when nothing in the path DB matches.
    """
    def _clean(name: str) -> str:
        # Strip quoting, bullets, and "video:"-style labels the model may emit.
        cleaned = (name or "").strip().strip("`").strip('"').strip("'").strip("[]")
        cleaned = cleaned.replace(":", ":")
        cleaned = re.sub(r"^(video_name|video|视频名|文件名|name)\s*:\s*", "", cleaned, flags=re.IGNORECASE).strip()
        cleaned = re.sub(r"^\s*[-*•]\s*", "", cleaned)
        return os.path.splitext(cleaned)[0]

    path_map = rag.video_path_db._data
    normalized = _clean(referenced_name)
    if normalized in path_map:
        return normalized, path_map[normalized]

    raw = (referenced_name or "").strip()
    raw_wo_ext = _clean(raw)
    for key, path in path_map.items():
        if key == raw_wo_ext:
            return key, path
        if key.lower() == raw.lower() or key.lower() == raw_wo_ext.lower():
            return key, path
        # Last resort: fuzzy containment either way.
        if raw_wo_ext.lower() in key.lower() or key.lower() in raw_wo_ext.lower():
            return key, path
    return None, None
|
|
|
|
| def _export_clip(video_path: str, start: float, end: float, working_dir: str, cache_key: str): |
| clip_dir = os.path.join(working_dir, "_webui_query_clips") |
| os.makedirs(clip_dir, exist_ok=True) |
| video_stem = os.path.splitext(os.path.basename(video_path))[0] |
| clip_name = f"{video_stem}_{int(start)}_{int(end)}_{cache_key[:8]}.mp4" |
| clip_path = os.path.join(clip_dir, clip_name) |
|
|
| if os.path.exists(clip_path): |
| return clip_path |
|
|
| cmd = [ |
| "ffmpeg", |
| "-y", |
| "-ss", |
| f"{start:.3f}", |
| "-to", |
| f"{end:.3f}", |
| "-i", |
| video_path, |
| "-c:v", |
| "libx264", |
| "-c:a", |
| "aac", |
| "-movflags", |
| "+faststart", |
| "-loglevel", |
| "error", |
| clip_path, |
| ] |
| proc = subprocess.run(cmd, capture_output=True, text=True) |
| if proc.returncode != 0: |
| raise RuntimeError(proc.stderr.strip() or "ffmpeg 裁剪失败") |
| return clip_path |
|
|
|
|
| def _upsert_env_file(env_path: str, updates: dict[str, str]): |
| key_pattern = re.compile(r"^\s*([A-Za-z_][A-Za-z0-9_]*)\s*=") |
| lines = [] |
| if os.path.exists(env_path): |
| with open(env_path, "r", encoding="utf-8") as f: |
| lines = f.readlines() |
|
|
| found = set() |
| new_lines = [] |
| for line in lines: |
| m = key_pattern.match(line) |
| if not m: |
| new_lines.append(line) |
| continue |
| k = m.group(1) |
| if k in updates: |
| safe_v = str(updates[k]).replace('"', '\\"') |
| new_lines.append(f'{k} = "{safe_v}"\n') |
| found.add(k) |
| else: |
| new_lines.append(line) |
|
|
| for k, v in updates.items(): |
| if k in found: |
| continue |
| safe_v = str(v).replace('"', '\\"') |
| new_lines.append(f'{k} = "{safe_v}"\n') |
|
|
| with open(env_path, "w", encoding="utf-8") as f: |
| f.writelines(new_lines) |
|
|
|
|
def apply_system_settings(
    llm_base_url: str,
    llm_api_key: str,
    llm_model_name: str,
    vlm_base_url: str,
    vlm_api_key: str,
    vlm_model_name: str,
    embedding_base_url: str,
    embedding_api_key: str,
    embedding_model_name: str,
    video_segment_length: float,
    rough_num_frames_per_segment: float,
    retrieval_topk_chunks: float,
    query_better_than_threshold: float,
    chunk_token_size: float,
    segment_retrieval_top_k: float,
):
    """Validate, persist and hot-apply every value from the settings tab.

    Writes the settings to os.environ and the .env file, attempts to
    hot-swap the LLM / VLM / embedding clients inside the VideoAgent
    modules, then drops the cached VideoRAG instance so the next index or
    query run is built with the new parameters.  Returns a status string
    shown in the UI; validation failures return early with an error text.
    """
    global _rag_runtime_settings, _videorag

    # gr.Textbox may hand us None; normalize every API field to a stripped str.
    llm_base_url = (llm_base_url or "").strip()
    llm_api_key = (llm_api_key or "").strip()
    llm_model_name = (llm_model_name or "").strip()
    vlm_base_url = (vlm_base_url or "").strip()
    vlm_api_key = (vlm_api_key or "").strip()
    vlm_model_name = (vlm_model_name or "").strip()
    embedding_base_url = (embedding_base_url or "").strip()
    embedding_api_key = (embedding_api_key or "").strip()
    embedding_model_name = (embedding_model_name or "").strip()

    # All nine API fields are mandatory.
    required = {
        "LLM_API_BASE_URL": llm_base_url,
        "LLM_API_KEY": llm_api_key,
        "LLM_MODEL_NAME": llm_model_name,
        "VLM_API_BASE_URL": vlm_base_url,
        "VLM_API_KEY": vlm_api_key,
        "VLM_MODEL_NAME": vlm_model_name,
        "EMBEDDING_API_BASE_URL": embedding_base_url,
        "EMBEDDING_API_KEY": embedding_api_key,
        "EMBEDDING_MODEL_NAME": embedding_model_name,
    }
    empties = [k for k, v in required.items() if not v]
    if empties:
        return "❌ 以下配置不能为空:\n- " + "\n- ".join(empties)

    # gr.Number delivers floats; coerce the integer parameters back to int.
    try:
        video_segment_length = int(video_segment_length)
        rough_num_frames_per_segment = int(rough_num_frames_per_segment)
        retrieval_topk_chunks = int(retrieval_topk_chunks)
        chunk_token_size = int(chunk_token_size)
        query_better_than_threshold = float(query_better_than_threshold)
        segment_retrieval_top_k = int(segment_retrieval_top_k)
    except Exception as e:
        return f"❌ 参数类型错误,请检查数值配置:{e}"

    # Range checks mirror the constraints on the UI inputs.
    if video_segment_length <= 0 or rough_num_frames_per_segment <= 0:
        return "❌ video_segment_length 和 rough_num_frames_per_segment 必须 > 0"
    if retrieval_topk_chunks <= 0:
        return "❌ retrieval_topk_chunks 必须 > 0"
    if chunk_token_size <= 0:
        return "❌ chunk_token_size 必须 > 0"
    if segment_retrieval_top_k <= 0:
        return "❌ segment_retrieval_top_k 必须 > 0"
    if not (0 <= query_better_than_threshold <= 1):
        return "❌ query_better_than_threshold 建议设置在 [0, 1] 区间"

    # Apply to the live process env and persist to .env for the next launch.
    updates = {
        **required,
        _RAG_ENV_MAP["video_segment_length"]: str(video_segment_length),
        _RAG_ENV_MAP["rough_num_frames_per_segment"]: str(rough_num_frames_per_segment),
        _RAG_ENV_MAP["retrieval_topk_chunks"]: str(retrieval_topk_chunks),
        _RAG_ENV_MAP["query_better_than_threshold"]: str(query_better_than_threshold),
        _RAG_ENV_MAP["chunk_token_size"]: str(chunk_token_size),
        _RAG_ENV_MAP["segment_retrieval_top_k"]: str(segment_retrieval_top_k),
    }
    for k, v in updates.items():
        os.environ[k] = v

    env_path = os.path.join(_script_dir, ".env")
    _upsert_env_file(env_path, updates)

    # Best-effort hot reload of the three module-level clients; each failure
    # is reported but never blocks saving the settings.
    applied_logs = []
    try:
        import VideoAgent.query as query_mod

        if hasattr(query_mod, "qwen3_model"):
            query_mod.qwen3_model.model_name = llm_model_name
            query_mod.qwen3_model.base_url = llm_base_url
            query_mod.qwen3_model.api_key = llm_api_key
            query_mod.qwen3_model.client = OpenAI(base_url=llm_base_url, api_key=llm_api_key)
            applied_logs.append("✅ LLM 客户端已热更新")
        else:
            applied_logs.append("⚠️ LLM 客户端未找到,需重启生效")
    except Exception as e:
        applied_logs.append(f"⚠️ LLM 热更新失败,需重启生效:{e}")

    try:
        import VideoAgent._videoutil.caption as caption_mod

        if hasattr(caption_mod, "model"):
            caption_mod.model.model_name = vlm_model_name
            caption_mod.model.client = OpenAI(base_url=vlm_base_url, api_key=vlm_api_key)
            applied_logs.append("✅ VLM 客户端已热更新")
        else:
            applied_logs.append("⚠️ VLM 客户端未找到,需重启生效")
    except Exception as e:
        applied_logs.append(f"⚠️ VLM 热更新失败,需重启生效:{e}")

    try:
        from VideoAgent._llm import Qwen3VLEmbedderC
        import VideoAgent._videoutil.feature as feature_mod

        # The embedder is replaced wholesale rather than patched in place.
        feature_mod.model = Qwen3VLEmbedderC(
            model_name=embedding_model_name,
            base_url=embedding_base_url,
            api_key=embedding_api_key,
        )
        applied_logs.append("✅ Embedding 客户端已热更新")
    except Exception as e:
        applied_logs.append(f"⚠️ Embedding 热更新失败,需重启生效:{e}")

    # Swap in the new runtime settings and drop the cached instance so
    # _get_rag rebuilds VideoRAG on next use.
    _rag_runtime_settings = {
        "video_segment_length": video_segment_length,
        "rough_num_frames_per_segment": rough_num_frames_per_segment,
        "retrieval_topk_chunks": retrieval_topk_chunks,
        "query_better_than_threshold": query_better_than_threshold,
        "chunk_token_size": chunk_token_size,
        "segment_retrieval_top_k": segment_retrieval_top_k,
    }
    _videorag = None
    applied_logs.append("✅ VideoAgent 参数已更新(下次索引/查询将按新参数实例化)")

    return (
        "🎉 系统设置已保存到 .env 并应用。\n"
        + "\n".join(applied_logs)
        + "\n\n若你替换了模型结构差异较大的服务,建议重启 webui 以确保完全生效。"
    )
|
|
|
|
| class _LogCapture(logging.Handler): |
| def __init__(self): |
| super().__init__() |
| self.q: _queue.Queue[str] = _queue.Queue() |
|
|
| def emit(self, record: logging.LogRecord): |
| self.q.put(self.format(record)) |
|
|
|
|
| |
def refresh_video_list(working_dir: str) -> str:
    """Return the formatted list of videos currently indexed in *working_dir*."""
    return _fmt_video_list(_read_indexed_videos(working_dir))
|
|
|
|
def index_videos(video_files, working_dir: str, progress=gr.Progress()):
    """Preprocess and index the uploaded videos, streaming (log, preview) pairs.

    Generator used as a Gradio event handler: yields (log_text, current_video)
    tuples so the UI updates live.  Preprocessing downsizes each new video
    into <working_dir>/processed; videos whose name already exists there are
    skipped.  Indexing runs on a worker thread while this generator drains
    its progress events and captured log records.
    """
    if not working_dir.strip():
        yield "❌ 错误:工作目录不能为空。", None
        return
    if not video_files:
        yield "❌ 错误:未选择视频文件。", None
        return

    os.makedirs(working_dir, exist_ok=True)

    # Preprocessed copies live alongside the index data.
    processed_dir = os.path.join(working_dir, "processed")
    os.makedirs(processed_dir, exist_ok=True)

    # Names (without extension) already present in processed/ — used to skip
    # re-uploading the same video.
    existing_videos = set()
    for file in os.listdir(processed_dir):
        if file.lower().endswith(('.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm')):
            video_name_without_ext = os.path.splitext(file)[0]
            existing_videos.add(video_name_without_ext)

    video_paths = []
    original_to_processed_map = {}
    skipped_videos = []

    for f in video_files:
        original_path = _get_path_from_file(f)
        if original_path and os.path.exists(original_path):
            original_video_name = os.path.basename(original_path)
            video_name_without_ext = os.path.splitext(original_video_name)[0]

            if video_name_without_ext in existing_videos:
                skipped_videos.append(original_video_name)
                continue

            # Downscale/resample before indexing to bound the VLM workload.
            from VideoAgent._videoutil.split import preprocess_video
            video_output_path = preprocess_video(
                original_path,
                target_width=384,
                target_height=384,
                target_fps=5,
                video_output_format="mp4"
            )

            video_name = os.path.basename(video_output_path)
            target_path = os.path.join(processed_dir, video_name)

            # Name collision in processed/ -> suffix with a timestamp.
            if os.path.exists(target_path):
                name, ext = os.path.splitext(video_name)
                timestamp = int(time.time())
                target_path = os.path.join(processed_dir, f"{name}_{timestamp}{ext}")

            # NOTE(review): shutil is already imported at module level; this
            # local import is redundant.
            import shutil
            shutil.move(video_output_path, target_path)
            video_paths.append(target_path)
            original_to_processed_map[original_path] = target_path

    if skipped_videos:
        log_lines = [f"⚠️ 以下视频已在processed目录中存在,跳过处理:{', '.join(skipped_videos)}"]
    else:
        log_lines = []

    if not video_paths:
        if skipped_videos:
            yield f"⚠️ 所有视频均已在processed目录中存在,无需重复处理。\n{chr(10).join(log_lines)}", None
        else:
            yield "❌ 错误:上传文件不可用,请重新上传后重试。", None
        return

    # Capture all root-logger output so library logs show up in the UI console.
    cap = _LogCapture()
    cap.setFormatter(logging.Formatter("%(asctime)s | %(message)s", "%H:%M:%S"))
    root_logger = logging.getLogger()
    root_logger.addHandler(cap)

    # Worker -> UI communication: progress events plus a completion/error flag.
    events: _queue.Queue = _queue.Queue()
    done_flag: dict[str, str | bool] = {}
    total = len(video_paths)

    def run():
        # Runs on a daemon thread; all failures are surfaced via done_flag.
        try:
            rag = _get_rag(working_dir)
            for i, path in enumerate(video_paths, start=1):
                events.put(("start", {"index": i, "path": path}))
                rag.insert_video(video_path_list=[path])

                video_name = os.path.basename(path).split('.')[0]

                from VideoAgent._utils import always_get_an_event_loop
                loop = always_get_an_event_loop()

                # Record the processed path so queries can locate the file.
                # Closure over video_name/path is safe: the coroutine is run
                # to completion before the next loop iteration rebinds them.
                async def update_video_path():
                    await rag.video_path_db.upsert({video_name: path})
                    await rag.video_path_db.index_done_callback()

                loop.run_until_complete(update_video_path())

                events.put(("done", {"index": i, "path": path}))
            done_flag["ok"] = True
        except Exception as e:
            done_flag["error"] = str(e)

    worker = threading.Thread(target=run, daemon=True)
    worker.start()

    current_video = video_paths[0] if video_paths else None
    log_lines.extend([f"🎬 准备索引 {total} 个视频..."])
    last_log_snapshot = ""
    last_video = None

    # Poll until the worker finishes AND both queues are fully drained.
    while worker.is_alive() or not events.empty() or not cap.q.empty():
        changed = False
        while True:
            try:
                event_type, value = events.get_nowait()
                if event_type == "start":
                    i = value["index"]
                    current_video = value["path"]
                    progress(float(i - 1) / total, desc=f"正在索引: {os.path.basename(current_video)}")
                    log_lines.append(f"🎞️ [{i}/{total}] 开始索引:{os.path.basename(current_video)}")
                elif event_type == "done":
                    i = value["index"]
                    current_video = value["path"]
                    progress(float(i) / total, desc=f"已完成: {os.path.basename(current_video)}")
                    log_lines.append(f"✅ [{i}/{total}] 完成索引:{os.path.basename(current_video)}")
                else:
                    log_lines.append(value)
                changed = True
            except _queue.Empty:
                break

        # Drain captured log records into the console view.
        while True:
            try:
                log_line = cap.q.get_nowait()
                log_lines.append(log_line)
                changed = True
            except _queue.Empty:
                break

        # Only yield when something actually changed; cap the console to the
        # most recent 120 lines.
        snapshot = "\n".join(log_lines[-120:])
        if changed or snapshot != last_log_snapshot or current_video != last_video:
            last_log_snapshot = snapshot
            last_video = current_video
            yield snapshot, current_video
        else:
            time.sleep(0.2)

    root_logger.removeHandler(cap)
    if skipped_videos:
        final_log = f"{last_log_snapshot}\n🎉 完成索引。跳过的视频:{', '.join(skipped_videos)}"
    else:
        final_log = f"{last_log_snapshot}\n🎉 全部索引完成。"

    # An error on the worker thread overrides the success message.
    if "error" in done_flag:
        final_log = f"{last_log_snapshot}\n❌ 索引失败:{done_flag['error']}"
    yield final_log, current_video
|
|
|
|
def query_videos(query_text: str, working_dir: str, progress=gr.Progress()):
    """Answer *query_text* against the indexed videos and cut reference clips.

    Gradio event handler.  Returns (markdown_answer, gallery_items) where
    gallery_items is a list of (clip_path, label) pairs for the clip gallery.
    All failures are reported inside the answer text instead of being raised.
    """
    if not query_text.strip():
        return "❌ 请输入问题", []

    try:
        rag = _get_rag(working_dir)
        progress(0.2, desc="模型分析中...")
        qparam = QueryParam()
        # Keep the naive-retrieval token budget in sync with the configured
        # chunk size.
        qparam.naive_max_token_for_text_unit = int(
            getattr(rag, "chunk_token_size", qparam.naive_max_token_for_text_unit)
        )
        answer = str(rag.query(query=query_text, param=qparam))
        progress(0.55, desc="解析 Reference 片段...")

        refs = _extract_reference_items(answer)
        if not refs:
            return (
                f"{answer}\n\nℹ️ 未解析到可播放片段(请确保答案包含\"参考/Reference: 视频名, 开始时间, 结束时间\"格式)。",
                [],
            )

        gallery_items = []
        # Fix: the original named this list `warnings`, shadowing the
        # module-level `warnings` import.
        failed_refs = []

        total = len(refs)
        for i, ref in enumerate(refs, start=1):
            progress(0.55 + (0.40 * i / total), desc=f"处理片段 [{ref['ref_id']}] ...")
            resolved_name, video_path = _resolve_video_path(rag, ref["video_name"])
            if not video_path:
                failed_refs.append(f"[{ref['ref_id']}] 未匹配到视频:{ref['video_name']}")
                continue

            # The cache key ties the exported clip to this exact query and
            # reference, so repeated queries reuse existing files.
            cache_key = hashlib.md5(
                f"{resolved_name}|{ref['start']}|{ref['end']}|{query_text}|{ref['ref_id']}".encode("utf-8")
            ).hexdigest()
            try:
                clip_path = _export_clip(
                    video_path=video_path,
                    start=ref["start"],
                    end=ref["end"],
                    working_dir=working_dir,
                    cache_key=cache_key,
                )
            except Exception as e:
                failed_refs.append(f"[{ref['ref_id']}] 裁剪失败:{e}")
                continue

            label = f"[{ref['ref_id']}] {resolved_name} {ref['start_text']} - {ref['end_text']}"
            gallery_items.append((clip_path, label))

        if not gallery_items:
            warn_text = ("\n".join(f"- {w}" for w in failed_refs)) if failed_refs else "- 未生成任何可播放片段"
            return (
                f"{answer}\n\n⚠️ 参考片段解析完成,但无法生成可播放视频:\n{warn_text}",
                [],
            )

        # Surface at most the first five per-clip failures below the answer.
        if failed_refs:
            answer = answer + "\n\n⚠️ 部分片段处理失败:\n" + "\n".join(f"- {w}" for w in failed_refs[:5])
        progress(1.0, desc="检索完成")
        return answer, gallery_items
    except Exception as e:
        # Top-level guard: any backend failure becomes a readable UI message.
        return f"❌ 查询异常: {e}", []
|
|
|
|
| |
| DEFAULT_WORKING_DIR = os.path.join(_script_dir, "working_dir") |
|
|
# UI definition: three tabs (index / retrieve / settings) plus event wiring.
with gr.Blocks(
    title="VideoRAG",
    theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="blue", neutral_hue="slate"),
    css=custom_css,
) as demo:
    # Working directory shared by every handler (constant for now).
    working_dir_state = gr.State(DEFAULT_WORKING_DIR)

    gr.HTML(
        """
        <div class="app-title">
            <h1>VideoAgent 控制台</h1>
        </div>
        """
    )

    with gr.Tabs():
        # ---- Indexing tab: upload, preprocess, index, live log ----
        with gr.Tab("索引"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=6):
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">上传视频</div>')
                        video_upload = gr.File(
                            label="文件",
                            file_count="multiple",
                            file_types=["video"],
                            type="filepath",
                            height=140,
                        )
                        with gr.Row():
                            index_btn = gr.Button("开始索引", variant="primary")
                            refresh_btn = gr.Button("刷新")
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">日志</div>')
                        index_log = gr.Textbox(
                            label="",
                            interactive=False,
                            lines=10,
                            max_lines=14,
                            elem_classes="console-font",
                            placeholder="等待任务...",
                        )

                with gr.Column(scale=4):
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">当前视频</div>')
                        index_video_preview = gr.Video(
                            label="",
                            height=240,
                            interactive=False,
                            elem_classes="video-box",
                        )
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">已索引</div>')
                        indexed_list = gr.Textbox(
                            label="",
                            value=refresh_video_list(DEFAULT_WORKING_DIR),
                            lines=6,
                            max_lines=8,
                            interactive=False,
                        )

        # ---- Retrieval tab: question input, answer markdown, clip gallery ----
        with gr.Tab("检索"):
            with gr.Group(elem_classes=["card-style", "search-toolbar"]):
                query_input = gr.Textbox(
                    label="问题",
                    placeholder="输入检索问题",
                    lines=2,
                    elem_classes="search-query",
                )
                with gr.Row(equal_height=True, elem_classes="search-actions"):
                    query_btn = gr.Button("开始检索", variant="primary", size="sm")
                    clear_query = gr.Button("清空", variant="secondary", size="sm")

            with gr.Row(equal_height=False, elem_classes="search-panel"):
                with gr.Column(scale=6):
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">检索结果</div>')
                        response_box = gr.Markdown("等待检索...", elem_classes="result-box")
                with gr.Column(scale=6):
                    with gr.Group(elem_classes="card-style"):
                        gr.HTML('<div class="section-label">片段(点击播放)</div>')
                        query_clip_gallery = gr.Gallery(
                            label="",
                            columns=2,
                            height=360,
                            interactive=False,
                            elem_classes="clip-gallery",
                        )

        # ---- Settings tab: model endpoints and VideoAgent parameters ----
        with gr.Tab("设置"):
            with gr.Group(elem_classes="card-style"):
                gr.HTML('<div class="section-label">模型 API 配置</div>')
                with gr.Accordion("LLM 配置", open=False):
                    with gr.Row(equal_height=True):
                        llm_base_url_input = gr.Textbox(
                            label="Base URL",
                            value=os.getenv("LLM_API_BASE_URL", "http://localhost:8000/v1"),
                            placeholder="https://xxx/v1",
                        )
                        llm_api_key_input = gr.Textbox(
                            label="API Key",
                            type="password",
                            value=os.getenv("LLM_API_KEY", ""),
                            placeholder="API Key",
                        )
                        llm_model_name_input = gr.Textbox(
                            label="模型名称",
                            value=os.getenv("LLM_MODEL_NAME", ""),
                            placeholder="model-name",
                        )

                with gr.Accordion("VLM 配置", open=False):
                    with gr.Row(equal_height=True):
                        vlm_base_url_input = gr.Textbox(
                            label="Base URL",
                            value=os.getenv("VLM_API_BASE_URL", "http://localhost:8012/v1"),
                            placeholder="https://xxx/v1",
                        )
                        vlm_api_key_input = gr.Textbox(
                            label="API Key",
                            type="password",
                            value=os.getenv("VLM_API_KEY", ""),
                            placeholder="API Key",
                        )
                        vlm_model_name_input = gr.Textbox(
                            label="模型名称",
                            value=os.getenv("VLM_MODEL_NAME", ""),
                            placeholder="model-name",
                        )

                with gr.Accordion("Embedding 配置", open=False):
                    with gr.Row(equal_height=True):
                        embedding_base_url_input = gr.Textbox(
                            label="Base URL",
                            value=os.getenv("EMBEDDING_API_BASE_URL", "http://localhost:8010/v1"),
                            placeholder="https://xxx/v1",
                        )
                        embedding_api_key_input = gr.Textbox(
                            label="API Key",
                            type="password",
                            value=os.getenv("EMBEDDING_API_KEY", ""),
                            placeholder="API Key",
                        )
                        embedding_model_name_input = gr.Textbox(
                            label="模型名称",
                            value=os.getenv("EMBEDDING_MODEL_NAME", ""),
                            placeholder="model-name",
                        )

            with gr.Group(elem_classes="card-style"):
                gr.HTML('<div class="section-label">VideoAgent 参数配置</div>')

                with gr.Row(elem_classes="param-row"):
                    with gr.Column(elem_classes="param-col"):
                        video_segment_length_input = gr.Number(
                            label="视频分段长度 (秒)",
                            value=_rag_runtime_settings["video_segment_length"],
                            minimum=1,
                            info="每个视频片段的持续时间",
                            precision=0,
                        )
                        retrieval_topk_chunks_input = gr.Number(
                            label="检索 Top-K 片段数",
                            value=_rag_runtime_settings["retrieval_topk_chunks"],
                            minimum=1,
                            info="检索相关片段的数量",
                            precision=0,
                        )
                    with gr.Column(elem_classes="param-col"):
                        rough_num_frames_per_segment_input = gr.Number(
                            label="每段采样帧数",
                            value=_rag_runtime_settings["rough_num_frames_per_segment"],
                            minimum=1,
                            info="每段视频采样的帧数",
                            precision=0,
                        )
                        segment_retrieval_top_k_input = gr.Number(
                            label="视频段检索 Top-K 数",
                            value=_rag_runtime_settings["segment_retrieval_top_k"],
                            minimum=1,
                            info="检索相关视频段的数量",
                            precision=0,
                        )

                with gr.Row(elem_classes="param-row"):
                    with gr.Column(elem_classes="param-col"):
                        query_better_than_threshold_input = gr.Number(
                            label="查询阈值",
                            value=_rag_runtime_settings["query_better_than_threshold"],
                            minimum=0,
                            maximum=1,
                            info="查询匹配的最小阈值",
                            precision=3,
                        )
                    with gr.Column(elem_classes="param-col"):
                        chunk_token_size_input = gr.Number(
                            label="文本块最大Token数",
                            value=_rag_runtime_settings["chunk_token_size"],
                            minimum=1,
                            info="单个文本块的最大token数量",
                            precision=0,
                        )

                apply_settings_btn = gr.Button("保存设置", variant="primary", size="lg", elem_classes="gr-button-primary apply-btn-container")
                settings_status = gr.Textbox(
                    label="状态信息",
                    lines=3,
                    max_lines=5,
                    interactive=False,
                    placeholder="等待保存...",
                    elem_classes="console-font"
                )

    # ---- Event wiring ----
    # Indexing streams (log, preview) updates, then refreshes the video list.
    index_btn.click(
        fn=index_videos,
        inputs=[video_upload, working_dir_state],
        outputs=[index_log, index_video_preview],
    ).then(
        fn=refresh_video_list,
        inputs=[working_dir_state],
        outputs=[indexed_list],
    )

    refresh_btn.click(
        fn=refresh_video_list,
        inputs=[working_dir_state],
        outputs=[indexed_list],
    )

    # Both the button and Enter in the textbox trigger retrieval.
    query_args = dict(
        fn=query_videos,
        inputs=[query_input, working_dir_state],
        outputs=[response_box, query_clip_gallery],
    )
    query_btn.click(**query_args)
    query_input.submit(**query_args)

    clear_query.click(
        lambda: ("", "等待检索...", []),
        None,
        [query_input, response_box, query_clip_gallery],
    )

    apply_settings_btn.click(
        fn=apply_system_settings,
        inputs=[
            llm_base_url_input,
            llm_api_key_input,
            llm_model_name_input,
            vlm_base_url_input,
            vlm_api_key_input,
            vlm_model_name_input,
            embedding_base_url_input,
            embedding_api_key_input,
            embedding_model_name_input,
            video_segment_length_input,
            rough_num_frames_per_segment_input,
            retrieval_topk_chunks_input,
            query_better_than_threshold_input,
            chunk_token_size_input,
            segment_retrieval_top_k_input,
        ],
        outputs=[settings_status],
    )
|
|
if __name__ == "__main__":
    # Listen on all interfaces; show_error surfaces handler tracebacks in the browser.
    demo.launch(server_name="0.0.0.0", server_port=7869, show_error=True)
|
|