diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..515f2fc0415714ea09bd7833adbbbde824f31393 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +backend/.venv/ +backend/data/ +backend/models/ +backend/config/ +backend/note_results/ +backend/static/ +backend/uploads/ +backend/*.db +backend/app/db/*.db +__pycache__/ +*.pyc +.env diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0be0dc895bc0f6b4644047f7d9e0a46bdb8c7ac1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,48 @@ +# VideoMemo 后端 —— Hugging Face Spaces(Docker SDK)部署用 Dockerfile。 +# +# 用法:HF Space 是一个独立 git 仓库,把它的根目录布置成: +# /Dockerfile ← 本文件(复制到 Space 根目录,重命名为 Dockerfile) +# /README.md ← deploy/hf-space/README.md(含 HF 必需的 frontmatter) +# /backend/... ← 从本项目复制整个 backend 目录过去 +# 然后 git push 到 Space,HF 会构建本文件(COPY 路径相对 Space 根目录)。 +# +# 镜像故意精简:只装 ffmpeg + 后端依赖,默认走 REST 飞书推送,不装 lark-cli。 +# 数据库用外接 Postgres(Supabase),通过 DATABASE_URL Secret 注入。 +ARG BASE_REGISTRY=docker.io +FROM ${BASE_REGISTRY}/library/python:3.11-slim + +# HF 在 huggingface.co 自家基础设施上构建/运行:用官方 PyPI 与默认 HF 端点, +# 不要用国内镜像(那会更慢甚至失败)。 +ARG PIP_INDEX=https://pypi.org/simple + +# fonts-liberation 提供与 Arial 度量兼容的 LiberationSans,替代仓库里的 arial.ttf +# (HF git 不收二进制,故字体不进仓库,改由镜像在构建时提供) +RUN apt-get update && \ + apt-get install -y --no-install-recommends ffmpeg curl fonts-liberation && \ + rm -rf /var/lib/apt/lists/* + +ENV PYTHONUNBUFFERED=1 \ + BACKEND_HOST=0.0.0.0 \ + BACKEND_PORT=8483 \ + STATIC=/static \ + OUT_DIR=/app/static/screenshots \ + IMAGE_BASE_URL=/static/screenshots \ + NOTE_OUTPUT_DIR=/app/data/note_results \ + DATA_DIR=/app/data + +WORKDIR /app + +# 先装依赖利用层缓存 +COPY backend/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -i ${PIP_INDEX} -r requirements.txt + +# 再复制后端代码 +COPY backend /app + +# 预建可写目录(HF 容器以 root 运行,这些目录是临时盘——重启会清空, +# 所以结构化数据务必走外接 DATABASE_URL;笔记/截图属临时数据,后续可再迁对象存储) +RUN mkdir -p /app/data/note_results /app/static/screenshots /app/config /app/fonts && \ + cp /usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf /app/fonts/arial.ttf + +EXPOSE 8483 +CMD ["python", "main.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4bddac1fe26cf7a8a645f0a5e647caf5bec80fc1 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +--- +title: VideoMemo Backend +emoji: 🎬 +colorFrom: indigo +colorTo: blue +sdk: docker +app_port: 8483 +pinned: false +--- + +# VideoMemo 后端(API) + +AI 视频笔记生成的后端服务。桌面端 / 网页端 / 浏览器插件连接本 Space 的地址使用。 + +- **结构化数据**(LLM 供应商配置与 API key、模型、关键词订阅、通知渠道、任务索引) + 持久化到外接 Postgres(Supabase),通过 `DATABASE_URL` Secret 配置。 +- **本 Space 公开可访问**:务必设置 `WEB_ACCESS_PASSWORD` Secret,否则任何人都能调用你的后端。 +- 笔记正文 / 截图 / 向量库当前仍是容器内临时文件,**重启会清空**(计划后续迁入 Postgres / 对象存储)。 + +> 部署步骤见仓库 `deploy/hf-space/DEPLOY.md`。 diff --git a/backend/.env.example b/backend/.env.example new file mode 100644 index 0000000000000000000000000000000000000000..0b37d5fc8fed3cb3822f83bc710df0d6e00ab589 --- /dev/null +++ b/backend/.env.example @@ -0,0 +1,12 @@ + +# 通用 +ENV=production +API_BASE_URL=http://127.0.0.1:8000 +SCREENSHOT_BASE_URL=http://127.0.0.1:8000/static/screenshots +STATIC=/static # 外部访问路径(URL 前缀) +OUT_DIR=./static/screenshots # 本地输出目录 +IMAGE_BASE_URL=/static/screenshots # 图片访问 URL +DATA_DIR=data +# transcriber 相关配置 +TRANSCRIBER_TYPE=fast-whisper # fast-whisper/bcut/kuaishou +WHISPER_MODEL_SIZE=base \ No newline at end of file diff --git a/backend/Dockerfile b/backend/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..bd0befb653dadc6178569025b91c353cbdf2bbf0 --- /dev/null +++ b/backend/Dockerfile @@ -0,0 +1,42 @@ +# BASE_REGISTRY 默认走 docker.io;国内拉不到 docker.io 时可换 daocloud / 阿里云 / 自建镜像源: +# docker-compose build --build-arg BASE_REGISTRY=docker.m.daocloud.io +# 或写到 docker-compose.yml 的 build.args / 环境变量里 +ARG BASE_REGISTRY=docker.io +FROM ${BASE_REGISTRY}/library/python:3.11-slim + +ARG APT_MIRROR=mirrors.tuna.tsinghua.edu.cn +ARG PIP_INDEX=https://pypi.tuna.tsinghua.edu.cn/simple + +RUN rm -f /etc/apt/sources.list && \ + rm -rf /etc/apt/sources.list.d/* && \ + echo "deb https://${APT_MIRROR}/debian bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \ + echo "deb https://${APT_MIRROR}/debian bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ + echo "deb https://${APT_MIRROR}/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \ + apt-get update && \ + apt-get install -y --no-install-recommends ffmpeg curl && \ + rm -rf /var/lib/apt/lists/* + +ENV PATH="/usr/bin:${PATH}" +ENV HF_ENDPOINT=https://hf-mirror.com + +# 飞书「推送方式 = lark-cli / auto」时需要官方 lark CLI(npm 包 @larksuite/cli,二进制名 lark-cli)。 +# 走 REST 直连推送则用不到,可按需删除本段以瘦身镜像。 +# 凭证通过 LARK_APP_ID / LARK_APP_SECRET 环境变量在运行时注入(由后端调用时传入),此处不写死。 +ARG NPM_REGISTRY=https://registry.npmmirror.com +RUN apt-get update && \ + apt-get install -y --no-install-recommends nodejs npm && \ + npm config set registry ${NPM_REGISTRY} && \ + npm install -g @larksuite/cli && \ + rm -rf /var/lib/apt/lists/* /root/.npm && \ + (lark-cli --version || true) + +WORKDIR /app + +# 先复制 requirements.txt 利用层缓存 +COPY ./backend/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -i ${PIP_INDEX} -r requirements.txt + +# 再复制应用代码(频繁变动不影响 pip 缓存层) +COPY ./backend /app + +CMD ["python", "main.py"] diff --git a/backend/Dockerfile.gpu b/backend/Dockerfile.gpu new file mode 100644 index 0000000000000000000000000000000000000000..59a760765e9a759dfbca6d3102474166daf8da2e --- /dev/null +++ b/backend/Dockerfile.gpu @@ -0,0 +1,40 @@ +# BASE_REGISTRY 默认走 docker.io;国内可换 daocloud / 阿里云镜像(注意所选镜像需支持 nvidia/cuda 命名空间) +ARG BASE_REGISTRY=docker.io +FROM ${BASE_REGISTRY}/nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 + +ARG APT_MIRROR=mirrors.tuna.tsinghua.edu.cn +ARG PIP_INDEX=https://pypi.tuna.tsinghua.edu.cn/simple + +RUN rm -f /etc/apt/sources.list && \ + rm -rf /etc/apt/sources.list.d/* && \ + echo "deb https://${APT_MIRROR}/ubuntu jammy main restricted universe multiverse" > /etc/apt/sources.list && \ + echo "deb https://${APT_MIRROR}/ubuntu jammy-updates main restricted universe multiverse" >> /etc/apt/sources.list && \ + echo "deb https://${APT_MIRROR}/ubuntu jammy-security main restricted universe multiverse" >> /etc/apt/sources.list && \ + apt-get update && \ + apt-get install -y --no-install-recommends ffmpeg python3-pip curl && \ + rm -rf /var/lib/apt/lists/* + +ENV HF_ENDPOINT=https://hf-mirror.com + +# 飞书「推送方式 = lark-cli / auto」时需要官方 lark CLI(npm 包 @larksuite/cli,二进制名 lark-cli)。 +# Ubuntu 22.04 自带 apt 的 Node 太旧(v12)跑不动新 CLI,这里用 NodeSource 装 Node 20。 +# 走 REST 直连推送则用不到,可按需删除本段以瘦身镜像。凭证由后端运行时经环境变量注入,不写死。 +ARG NPM_REGISTRY=https://registry.npmmirror.com +RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \ + apt-get install -y --no-install-recommends nodejs && \ + npm config set registry ${NPM_REGISTRY} && \ + npm install -g @larksuite/cli && \ + rm -rf /var/lib/apt/lists/* /root/.npm && \ + (lark-cli --version || true) + +WORKDIR /app + +# 先复制 requirements.txt 利用层缓存 +COPY ./backend/requirements.txt /app/requirements.txt +RUN pip install --no-cache-dir -i ${PIP_INDEX} -r requirements.txt && \ + pip install --no-cache-dir -i ${PIP_INDEX} 'transformers[torch]>=4.23' + +# 再复制应用代码 +COPY ./backend /app + +CMD ["python3", "main.py"] diff --git a/backend/__init__.py b/backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/__init__.py b/backend/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..357498661f101e8084bb0f64d750326f95e244eb --- /dev/null +++ b/backend/app/__init__.py @@ -0,0 +1,47 @@ +import os +from typing import Optional + +from fastapi import Depends, FastAPI, Header, HTTPException, Request + +# 健康/诊断类接口:公网前端在用户尚未填访问密码时,也要能判断后端是否可达、 +# 从而正常加载页面(否则启动探测被密码拦成 401,整页卡在「连接中」无法进入设置去填密码)。 +_AUTH_EXEMPT_PATHS = {"/api/sys_check", "/api/sys_health", "/api/deploy_status"} + + +async def verify_web_access_password( + request: Request, + request_web_access_password: Optional[str] = Header( + None, alias="request-web-access-password" + ) +): + if request.url.path in _AUTH_EXEMPT_PATHS: + return True + expected = os.getenv("WEB_ACCESS_PASSWORD") + if expected and request_web_access_password != expected: + raise HTTPException(status_code=401, detail="访问密码错误或未填写") + return True + +def create_app(lifespan) -> FastAPI: + from .routers import note, notification, provider, model, config, chat, flashcard, hot_videos, article, trend_subscription, feishu + from .utils.response import ResponseWrapper as R + + app = FastAPI(title="VideoMemo",lifespan=lifespan) + protected = [Depends(verify_web_access_password)] + + @app.get("/sys_check") + async def root_sys_check(): + return R.success() + + app.include_router(note.router, prefix="/api", dependencies=protected) + app.include_router(provider.router, prefix="/api", dependencies=protected) + app.include_router(model.router, prefix="/api", dependencies=protected) + app.include_router(config.router, prefix="/api", dependencies=protected) + app.include_router(chat.router, prefix="/api", dependencies=protected) + app.include_router(flashcard.router, prefix="/api", dependencies=protected) + app.include_router(hot_videos.router, prefix="/api", dependencies=protected) + app.include_router(article.router, prefix="/api", dependencies=protected) + app.include_router(trend_subscription.router, prefix="/api", dependencies=protected) + app.include_router(notification.router, prefix="/api", dependencies=protected) + app.include_router(feishu.router, prefix="/api", dependencies=protected) + + return app diff --git a/backend/app/article_fetchers/__init__.py b/backend/app/article_fetchers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..78e8be2dc9ec721c480830226ef809841e38056b --- /dev/null +++ b/backend/app/article_fetchers/__init__.py @@ -0,0 +1,3 @@ +from app.article_fetchers.base import ArticleContent, ArticleFetcher, ArticleFetchError + +__all__ = ["ArticleContent", "ArticleFetcher", "ArticleFetchError"] diff --git a/backend/app/article_fetchers/base.py b/backend/app/article_fetchers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c022464dcb3f8bb1ecf7d29b1ac72a4a602df31f --- /dev/null +++ b/backend/app/article_fetchers/base.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Protocol + + +@dataclass +class ArticleContent: + platform: str + url: str + article_id: str + title: str + author_name: str = "" + author_id: str = "" + content_text: str = "" + image_urls: list[str] = field(default_factory=list) + cover_url: str = "" + published_at: str = "" + raw_metadata: dict = field(default_factory=dict) + + +class ArticleFetchError(Exception): + pass + + +class ArticleFetcher(Protocol): + platform: str + + def fetch(self, url: str) -> ArticleContent: + ... + + def search(self, keyword: str, limit: int = 20) -> list[ArticleContent]: + ... + + def fetch_publisher(self, query: str, limit: int = 20) -> list[ArticleContent]: + ... diff --git a/backend/app/article_fetchers/generic.py b/backend/app/article_fetchers/generic.py new file mode 100644 index 0000000000000000000000000000000000000000..d59134aac9f963765359402c1748d2d9fca024bd --- /dev/null +++ b/backend/app/article_fetchers/generic.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import re +from urllib.parse import urlparse + +import requests +from bs4 import BeautifulSoup + +from app.article_fetchers.base import ArticleContent, ArticleFetchError +from app.utils.url_parser import clean_url + + +def _clean_text(value: str) -> str: + return re.sub(r"[ \t\r\f\v]+", " ", value or "").strip() + + +def _normalize_body(value: str) -> str: + lines = [_clean_text(line) for line in (value or "").splitlines()] + return "\n".join(line for line in lines if line) + + +def _meta_content(soup: BeautifulSoup, *selectors: tuple[str, str]) -> str: + for attr, value in selectors: + node = soup.find("meta", attrs={attr: value}) + if node: + content = _clean_text(node.get("content") or "") + if content: + return content + return "" + + +def _candidate_score(node) -> int: + text = _normalize_body(node.get_text("\n")) + paragraphs = node.find_all("p") + return len(text) + len(paragraphs) * 120 + + +def parse_generic_article_html(html: str, url: str) -> ArticleContent: + soup = BeautifulSoup(html, "html.parser") + for tag in soup(["script", "style", "noscript", "svg", "canvas", "iframe"]): + tag.decompose() + for tag in soup(["nav", "header", "footer", "aside", "form"]): + tag.decompose() + + title = ( + _meta_content(soup, ("property", "og:title"), ("name", "twitter:title")) + or _clean_text(soup.title.get_text(" ")) if soup.title else "" + ) + author = _meta_content(soup, ("name", "author"), ("property", "article:author")) + published_at = _meta_content( + soup, + ("property", "article:published_time"), + ("name", "publishdate"), + ("name", "date"), + ) + cover = _meta_content(soup, ("property", "og:image"), ("name", "twitter:image")) + + candidates = [] + for selector in ("article", "main", "[role='main']", "#content", ".content", ".article", ".post"): + candidates.extend(soup.select(selector)) + if not candidates and soup.body: + candidates = [soup.body] + best = max(candidates, key=_candidate_score, default=None) + body = _normalize_body(best.get_text("\n")) if best else "" + if len(body) < 80: + description = _meta_content(soup, ("name", "description"), ("property", "og:description")) + body = description if len(description) > len(body) else body + if len(body) < 40: + raise ValueError("网页正文为空或过短,无法生成总结") + + parsed = urlparse(url) + article_id = parsed.netloc + parsed.path + return ArticleContent( + platform="generic_web", + url=url, + article_id=article_id or url, + title=title or parsed.netloc or "网页文章", + author_name=author, + content_text=body, + image_urls=[cover] if cover else [], + cover_url=cover, + published_at=published_at, + raw_metadata={"source": "generic_web"}, + ) + + +class GenericArticleFetcher: + platform = "generic_web" + + def fetch(self, url: str) -> ArticleContent: + clean = clean_url(url) + try: + response = requests.get( + clean, + timeout=12, + allow_redirects=True, + headers={ + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0 Safari/537.36" + ), + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8", + }, + ) + response.raise_for_status() + return parse_generic_article_html(response.text, response.url or clean) + except ValueError: + raise + except Exception as exc: + raise ArticleFetchError(f"网页文章抓取失败:{exc}") from exc + + def search(self, keyword: str, limit: int = 20) -> list[ArticleContent]: + raise ArticleFetchError("通用网页暂不支持关键字查询,请粘贴具体文章链接") + + def fetch_publisher(self, query: str, limit: int = 20) -> list[ArticleContent]: + raise ArticleFetchError("通用网页暂不支持发布者订阅,请粘贴具体文章链接") diff --git a/backend/app/article_fetchers/wechat.py b/backend/app/article_fetchers/wechat.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c57a5b68b244960a1fc018d84250fa2bc1fa7c --- /dev/null +++ b/backend/app/article_fetchers/wechat.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import re +from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse + +import requests +from bs4 import BeautifulSoup + +from app.article_fetchers.base import ArticleContent, ArticleFetchError + + +def _clean_text(value: str) -> str: + return re.sub(r"\s+", " ", value or "").strip() + + +def _element_text(element) -> str: + return _clean_text(element.get_text(" ")) if element else "" + + +def _script_value(html: str, name: str) -> str: + patterns = [ + rf'var\s+{re.escape(name)}\s*=\s*"([^"]*)"', + rf"{re.escape(name)}\s*:\s*'([^']*)'", + ] + for pattern in patterns: + match = re.search(pattern, html) + if match: + return match.group(1).strip() + return "" + + +def parse_wechat_article_html(html: str, url: str) -> ArticleContent: + soup = BeautifulSoup(html, "html.parser") + title = _element_text(soup.find(id="activity-name") or soup.find("h1")) + author = _element_text(soup.find(id="js_name")) + published_at = _element_text(soup.find(id="publish_time")) + content = soup.find(id="js_content") + body = _clean_text(content.get_text("\n")) if content else "" + if not body: + raise ValueError("微信公众号文章正文为空,无法生成总结") + + image_urls: list[str] = [] + for image in content.find_all("img") if content else []: + src = image.get("data-src") or image.get("src") or "" + if src and src not in image_urls: + image_urls.append(src) + + biz = _script_value(html, "biz") + mid = _script_value(html, "mid") + idx = _script_value(html, "idx") + sn = _script_value(html, "sn") + article_id = ":".join(part for part in [biz, mid, idx, sn] if part) or url + + return ArticleContent( + platform="wechat_mp", + url=url, + article_id=article_id, + title=title or "微信公众号文章", + author_name=author, + author_id=biz, + content_text=body, + image_urls=image_urls, + cover_url=image_urls[0] if image_urls else "", + published_at=published_at, + raw_metadata={"biz": biz, "mid": mid, "idx": idx, "sn": sn}, + ) + + +def _normalize_wechat_result_url(href: str) -> str: + if not href: + return "" + absolute = urljoin("https://weixin.sogou.com", href) + parsed = urlparse(absolute) + query = parse_qs(parsed.query) + for key in ("url", "target"): + if query.get(key): + candidate = unquote(query[key][0]) + if "mp.weixin.qq.com" in candidate: + return candidate + return absolute if "mp.weixin.qq.com" in absolute else "" + + +def parse_wechat_search_html(html: str, keyword: str, limit: int = 20) -> list[ArticleContent]: + soup = BeautifulSoup(html, "html.parser") + items: list[ArticleContent] = [] + seen: set[str] = set() + for anchor in soup.find_all("a", href=True): + url = _normalize_wechat_result_url(anchor.get("href") or "") + if not url or url in seen: + continue + title = _clean_text(anchor.get_text(" ")) + if not title: + continue + container = anchor.find_parent(["div", "li"]) or anchor.parent + info_nodes = container.find_all(class_=re.compile(r"(txt-info|s-p|account)")) if container else [] + info = [_clean_text(node.get_text(" ")) for node in info_nodes if _clean_text(node.get_text(" "))] + author = info[0] if info else "" + summary = info[-1] if len(info) > 1 else title + seen.add(url) + items.append( + ArticleContent( + platform="wechat_mp", + url=url, + article_id=url, + title=title, + author_name=author, + content_text=summary, + raw_metadata={"keyword": keyword, "source": "sogou_weixin"}, + ) + ) + if len(items) >= limit: + break + return items + + +class WechatArticleFetcher: + platform = "wechat_mp" + + def fetch(self, url: str) -> ArticleContent: + try: + response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"}) + response.raise_for_status() + return parse_wechat_article_html(response.text, url) + except ValueError: + raise + except Exception as exc: + raise ArticleFetchError(f"微信公众号文章抓取失败:{exc}") from exc + + def search(self, keyword: str, limit: int = 20) -> list[ArticleContent]: + try: + response = requests.get( + f"https://weixin.sogou.com/weixin?type=2&query={quote(keyword)}", + timeout=10, + headers={"User-Agent": "Mozilla/5.0"}, + ) + response.raise_for_status() + return parse_wechat_search_html(response.text, keyword, limit) + except Exception as exc: + raise ArticleFetchError(f"微信公众号关键字查询失败:{exc}") from exc + + def fetch_publisher(self, query: str, limit: int = 20) -> list[ArticleContent]: + return self.search(query, limit) diff --git a/backend/app/article_fetchers/xiaohongshu.py b/backend/app/article_fetchers/xiaohongshu.py new file mode 100644 index 0000000000000000000000000000000000000000..4be625c153304829196e1287a3a9895d9110b903 --- /dev/null +++ b/backend/app/article_fetchers/xiaohongshu.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import json +import re +from datetime import datetime +from urllib.parse import quote, urlparse + +import requests +from bs4 import BeautifulSoup + +from app.article_fetchers.base import ArticleContent, ArticleFetchError +from app.services.cookie_manager import CookieConfigManager +from app.utils.url_parser import clean_url + + +def _note_id_from_url(url: str) -> str: + path = urlparse(url).path.rstrip("/") + return path.split("/")[-1] if path else url + + +def _extract_initial_state(html: str) -> dict: + match = re.search(r"window\.__INITIAL_STATE__\s*=", html) + if not match: + return {} + start = html.find("{", match.end()) + if start < 0: + return {} + depth = 0 + end = -1 + for index in range(start, len(html)): + char = html[index] + if char == "{": + depth += 1 + elif char == "}": + depth -= 1 + if depth == 0: + end = index + 1 + break + if end < 0: + return {} + raw = html[start:end].replace("undefined", "null") + try: + return json.loads(raw) + except json.JSONDecodeError: + return {} + + +def _first_image_url(item: dict) -> str: + for key in ("urlDefault", "url", "traceId"): + value = item.get(key) + if isinstance(value, str) and value.startswith("http"): + return value + nested = item.get("cover") or item.get("image") or {} + if isinstance(nested, dict): + for key in ("urlDefault", "url"): + value = nested.get(key) + if isinstance(value, str) and value.startswith("http"): + return value + return "" + + +def _published_at(value) -> str: + try: + timestamp = int(value) + except (TypeError, ValueError): + return "" + if timestamp > 10_000_000_000: + timestamp = timestamp // 1000 + return datetime.fromtimestamp(timestamp).isoformat(timespec="seconds") + + +def _article_from_note(note: dict, url: str) -> ArticleContent: + user = note.get("user") or {} + images: list[str] = [] + for image in note.get("imageList") or note.get("images") or []: + src = _first_image_url(image) + if src and src not in images: + images.append(src) + + content = str(note.get("desc") or note.get("description") or "").strip() + title = str(note.get("title") or "").strip() or content[:40] or "小红书笔记" + article_id = str(note.get("noteId") or note.get("id") or _note_id_from_url(url)).strip() + if not content: + raise ValueError("小红书笔记正文为空,无法生成总结") + + return ArticleContent( + platform="xiaohongshu", + url=url, + article_id=article_id, + title=title, + author_name=str(user.get("nickname") or "").strip(), + author_id=str(user.get("userId") or user.get("id") or "").strip(), + content_text=content, + image_urls=images, + cover_url=images[0] if images else "", + published_at=_published_at(note.get("time") or note.get("lastUpdateTime")), + raw_metadata={"raw_note": note}, + ) + + +def parse_xiaohongshu_article_html(html: str, url: str) -> ArticleContent: + state = _extract_initial_state(html) + detail_map = ((state.get("note") or {}).get("noteDetailMap")) or {} + for value in detail_map.values(): + note = value.get("note") if isinstance(value, dict) else None + if isinstance(note, dict): + return _article_from_note(note, url) + + soup = BeautifulSoup(html, "html.parser") + title_meta = soup.find("meta", attrs={"property": "og:title"}) + desc_meta = soup.find("meta", attrs={"name": "description"}) + title = (title_meta.get("content") if title_meta else "") or "小红书笔记" + body = (desc_meta.get("content") if desc_meta else "").strip() + if not body: + raise ValueError("小红书笔记正文为空,无法生成总结") + + return ArticleContent( + platform="xiaohongshu", + url=url, + article_id=_note_id_from_url(url), + title=title.strip(), + content_text=body, + ) + + +def _iter_note_like(value): + if isinstance(value, dict): + note_id = value.get("noteId") or value.get("id") + title = value.get("title") or value.get("displayTitle") + desc = value.get("desc") or value.get("description") + if note_id and (title or desc): + yield value + for child in value.values(): + yield from _iter_note_like(child) + elif isinstance(value, list): + for child in value: + yield from _iter_note_like(child) + + +def parse_xiaohongshu_discovery_html( + html: str, + source_url: str, + limit: int = 20, +) -> list[ArticleContent]: + state = _extract_initial_state(html) + items: list[ArticleContent] = [] + seen: set[str] = set() + for note in _iter_note_like(state): + article_id = str(note.get("noteId") or note.get("id") or "").strip() + if not article_id or article_id in seen: + continue + user = note.get("user") or note.get("author") or {} + image_url = _first_image_url(note) + content = str(note.get("desc") or note.get("description") or note.get("title") or "").strip() + title = str(note.get("title") or note.get("displayTitle") or content[:40] or "小红书笔记").strip() + seen.add(article_id) + items.append( + ArticleContent( + platform="xiaohongshu", + url=f"https://www.xiaohongshu.com/explore/{article_id}", + article_id=article_id, + title=title, + author_name=str(user.get("nickname") or user.get("name") or "").strip(), + author_id=str(user.get("userId") or user.get("id") or "").strip(), + content_text=content, + image_urls=[image_url] if image_url else [], + cover_url=image_url, + raw_metadata={"source_url": source_url}, + ) + ) + if len(items) >= limit: + break + return items + + +class XiaohongshuArticleFetcher: + platform = "xiaohongshu" + + def __init__(self): + self._cookie_mgr = CookieConfigManager() + + def _headers(self) -> dict: + headers = {"User-Agent": "Mozilla/5.0"} + cookie = self._cookie_mgr.get("xiaohongshu") + if cookie: + headers["Cookie"] = cookie + return headers + + def fetch(self, url: str) -> ArticleContent: + clean = clean_url(url) + try: + response = requests.get(clean, timeout=10, headers=self._headers(), allow_redirects=True) + response.raise_for_status() + return parse_xiaohongshu_article_html(response.text, response.url or clean) + except ValueError: + raise + except Exception as exc: + raise ArticleFetchError(f"小红书笔记抓取失败:{exc}") from exc + + def search(self, keyword: str, limit: int = 20) -> list[ArticleContent]: + url = f"https://www.xiaohongshu.com/search_result?keyword={quote(keyword)}" + try: + response = requests.get(url, timeout=10, headers=self._headers()) + response.raise_for_status() + return parse_xiaohongshu_discovery_html(response.text, url, limit) + except Exception as exc: + raise ArticleFetchError(f"小红书关键字查询失败:{exc}") from exc + + def fetch_publisher(self, query: str, limit: int = 20) -> list[ArticleContent]: + url = clean_url(query) + if not url.startswith("http"): + url = f"https://www.xiaohongshu.com/user/profile/{quote(query)}" + try: + response = requests.get(url, timeout=10, headers=self._headers(), allow_redirects=True) + response.raise_for_status() + return parse_xiaohongshu_discovery_html(response.text, response.url or url, limit) + except Exception as exc: + raise ArticleFetchError(f"小红书发布者订阅刷新失败:{exc}") from exc diff --git a/backend/app/core/__init__.py b/backend/app/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/db/article_dao.py b/backend/app/db/article_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..32261c869291dce4e5f2cb8ec6ca7d24e115bae0 --- /dev/null +++ b/backend/app/db/article_dao.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import hashlib +import json +from datetime import datetime + +from app.article_fetchers.base import ArticleContent +from app.db.engine import get_db +from app.db.models.articles import ArticleItem, ArticleSubscription, ArticleSubscriptionItem + + +def url_hash(url: str) -> str: + return hashlib.sha256(url.encode("utf-8")).hexdigest() + + +def _detach(obj): + data = {key: value for key, value in obj.__dict__.items() if not key.startswith("_")} + obj.__dict__.clear() + obj.__dict__.update(data) + return obj + + +def upsert_article_item(article: ArticleContent) -> ArticleItem: + db = next(get_db()) + try: + digest = url_hash(article.url) + item = None + if article.article_id: + item = ( + db.query(ArticleItem) + .filter_by(platform=article.platform, article_id=article.article_id) + .first() + ) + if item is None: + item = db.query(ArticleItem).filter_by(platform=article.platform, url_hash=digest).first() + if item is None: + item = ArticleItem( + platform=article.platform, + article_id=article.article_id, + url_hash=digest, + url=article.url, + title=article.title, + ) + db.add(item) + item.url = article.url + item.title = article.title + item.author_name = article.author_name + item.author_id = article.author_id + item.cover_url = article.cover_url + item.published_at = article.published_at + item.content_text = article.content_text + item.raw_metadata = json.dumps(article.raw_metadata or {}, ensure_ascii=False) + db.commit() + db.refresh(item) + return _detach(item) + finally: + db.close() + + +def get_article_item(item_id: int) -> ArticleItem | None: + db = next(get_db()) + try: + item = db.query(ArticleItem).filter_by(id=item_id).first() + return _detach(item) if item else None + finally: + db.close() + + +def list_article_items(subscription_id: int | None = None) -> list[ArticleItem]: + db = next(get_db()) + try: + query = db.query(ArticleItem) + if subscription_id is not None: + query = query.join( + ArticleSubscriptionItem, + ArticleSubscriptionItem.article_item_id == ArticleItem.id, + ).filter(ArticleSubscriptionItem.subscription_id == subscription_id) + return [_detach(item) for item in query.order_by(ArticleItem.id.desc()).all()] + finally: + db.close() + + +def mark_article_summarized(item_id: int, task_id: str) -> None: + db = next(get_db()) + try: + item = db.query(ArticleItem).filter_by(id=item_id).first() + if item: + item.summary_status = "summarized" + item.task_id = task_id + db.commit() + finally: + db.close() + + +def create_subscription( + platform: str, + subscription_type: str, + query: str, + label: str = "", +) -> ArticleSubscription: + db = next(get_db()) + try: + subscription = ArticleSubscription( + platform=platform, + type=subscription_type, + query=query, + label=label or query, + ) + db.add(subscription) + db.commit() + db.refresh(subscription) + return _detach(subscription) + finally: + db.close() + + +def list_subscriptions() -> list[ArticleSubscription]: + db = next(get_db()) + try: + return [ + _detach(item) + for item in db.query(ArticleSubscription).order_by(ArticleSubscription.id.desc()).all() + ] + finally: + db.close() + + +def get_subscription(subscription_id: int) -> ArticleSubscription | None: + db = next(get_db()) + try: + item = db.query(ArticleSubscription).filter_by(id=subscription_id).first() + return _detach(item) if item else None + finally: + db.close() + + +def update_subscription_refresh(subscription_id: int, error: str = "") -> None: + db = next(get_db()) + try: + item = db.query(ArticleSubscription).filter_by(id=subscription_id).first() + if item: + item.last_refresh_at = datetime.now() + item.last_error = error + db.commit() + finally: + db.close() + + +def link_subscription_item(subscription_id: int, article_item_id: int, match_reason: str) -> None: + db = next(get_db()) + try: + existing = ( + db.query(ArticleSubscriptionItem) + .filter_by(subscription_id=subscription_id, article_item_id=article_item_id) + .first() + ) + if existing is None: + db.add( + ArticleSubscriptionItem( + subscription_id=subscription_id, + article_item_id=article_item_id, + match_reason=match_reason, + ) + ) + db.commit() + finally: + db.close() diff --git a/backend/app/db/builtin_providers.json b/backend/app/db/builtin_providers.json new file mode 100644 index 0000000000000000000000000000000000000000..4e3385e2d59cfb521bd55bc8f1fca2874a8d04c3 --- /dev/null +++ b/backend/app/db/builtin_providers.json @@ -0,0 +1,65 @@ +[ + { + "id": "openai", + "name": "OpenAI", + "type": "built-in", + "logo": "OpenAI", + "api_key": "", + "base_url": "https://api.openai.com/v1", + "enabled": 0 + }, + { + "id": "deepseek", + "name": "DeepSeek", + "type": "built-in", + "logo": "DeepSeek", + "api_key": "", + "base_url": "https://api.deepseek.com", + "enabled": 1 + }, + { + "id": "qwen", + "name": "Qwen", + "type": "built-in", + "logo": "Qwen", + "api_key": "", + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + "enabled": 0 + }, + { + "id": "Claude", + "name": "Claude", + "type": "built-in", + "logo": "Claude", + "api_key": "", + "base_url": "https://", + "enabled": 0 + }, + { + "id": "gemini", + "name": "Gemini", + "type": "built-in", + "logo": "Gemini", + "api_key": "", + "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/", + "enabled": 0 + }, + { + "id": "groq", + "name": "Groq", + "type": "built-in", + "logo": "Groq", + "api_key": "", + "base_url": "https://api.groq.com/openai/v1", + "enabled": 0 + }, + { + "id": "ollama", + "name": "ollama", + "type": "built-in", + "logo": "Ollama", + "api_key": "", + "base_url": "http://127.0.0.1:11434/v1", + "enabled": 0 + } +] diff --git a/backend/app/db/engine.py b/backend/app/db/engine.py new file mode 100644 index 0000000000000000000000000000000000000000..2331d5ea902449c05a7ea6eb04b21ade99928842 --- /dev/null +++ b/backend/app/db/engine.py @@ -0,0 +1,45 @@ +import os +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, declarative_base +from dotenv import load_dotenv + +load_dotenv() + +# 默认 SQLite,如果想换 PostgreSQL 或 MySQL,可以直接改 .env +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///video_memo.db") + +# SQLite 需要特定连接参数,其他数据库不需要 +engine_args = {} +if DATABASE_URL.startswith("sqlite"): + engine_args["connect_args"] = {"check_same_thread": False} + +_pool_args = {} +if not DATABASE_URL.startswith("sqlite"): + _pool_args = { + "pool_size": int(os.getenv("DB_POOL_SIZE", "10")), + "max_overflow": int(os.getenv("DB_MAX_OVERFLOW", "20")), + "pool_pre_ping": True, + } + +engine = create_engine( + DATABASE_URL, + echo=os.getenv("SQLALCHEMY_ECHO", "false").lower() == "true", + **engine_args, + **_pool_args, +) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +Base = declarative_base() + + +def get_engine(): + return engine + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/backend/app/db/init_db.py b/backend/app/db/init_db.py new file mode 100644 index 0000000000000000000000000000000000000000..e1182cc6ca7dafd798489c1cdc2db733e837f6cc --- /dev/null +++ b/backend/app/db/init_db.py @@ -0,0 +1,34 @@ +from app.db.models.articles import ArticleItem, ArticleSubscription, ArticleSubscriptionItem +from app.db.models.models import Model +from app.db.models.providers import Provider +from app.db.models.trend_subscription import ( + NotificationChannel, + TrendSubscription, + TrendSubscriptionMatch, +) +from app.db.models.video_tasks import VideoTask +from app.db.engine import get_engine, Base +from sqlalchemy import inspect, text + +def init_db(): + engine = get_engine() + + Base.metadata.create_all(bind=engine) + _ensure_article_content_text(engine) + + +# 注:原 _ensure_model_columns 为 models.supports_multimodal 做的迁移已删除—— +# 该列在「drop multimodal」重构后已不再被 ORM 使用(纯遗留),且它的 +# `ALTER ... BOOLEAN NOT NULL DEFAULT 0` 在 Postgres 上会因 boolean 默认值类型不符直接报错。 +# 已有 SQLite 库里残留的该列无害,保持不动即可。 + + +def _ensure_article_content_text(engine): + inspector = inspect(engine) + if "article_items" not in inspector.get_table_names(): + return + columns = {column["name"] for column in inspector.get_columns("article_items")} + if "content_text" in columns: + return + with engine.begin() as conn: + conn.execute(text("ALTER TABLE article_items ADD COLUMN content_text TEXT NOT NULL DEFAULT ''")) diff --git a/backend/app/db/model_dao.py b/backend/app/db/model_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..1111e679184b53d2401d363d746c7f30ad21e950 --- /dev/null +++ b/backend/app/db/model_dao.py @@ -0,0 +1,69 @@ +from app.db.engine import get_db +from app.db.models.models import Model +from app.db.models.providers import Provider + + +def get_model_by_provider_and_name(provider_id: int, model_name: str): + db = next(get_db()) + try: + model = db.query(Model).filter_by(provider_id=provider_id, model_name=model_name).first() + if model: + return { + "id": model.id, + "provider_id": model.provider_id, + "model_name": model.model_name, + "created_at": model.created_at, + } + return None + finally: + db.close() + + +def insert_model(provider_id: int, model_name: str): + db = next(get_db()) + try: + model = Model(provider_id=provider_id, model_name=model_name) + db.add(model) + db.commit() + db.refresh(model) + return { + "id": model.id, + "provider_id": model.provider_id, + "model_name": model.model_name, + "created_at": model.created_at, + } + finally: + db.close() + + +def get_models_by_provider(provider_id: int): + db = next(get_db()) + try: + models = db.query(Model).filter_by(provider_id=provider_id).all() + return [{"id": m.id, "model_name": m.model_name} for m in models] + finally: + db.close() + + +def delete_model(model_id: int): + db = next(get_db()) + try: + model = db.query(Model).filter_by(id=model_id).first() + if model: + db.delete(model) + db.commit() + finally: + db.close() + + +def get_all_models(): + db = next(get_db()) + try: + # 只查询启用状态供应商的模型 + models = db.query(Model).join(Provider, Model.provider_id == Provider.id).filter(Provider.enabled == 1).all() + return [ + {"id": m.id, "provider_id": m.provider_id, "model_name": m.model_name} + for m in models + ] + finally: + db.close() \ No newline at end of file diff --git a/backend/app/db/models/__init__.py b/backend/app/db/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/db/models/articles.py b/backend/app/db/models/articles.py new file mode 100644 index 0000000000000000000000000000000000000000..203526be833455a750d8293be1f7f2255f296ec4 --- /dev/null +++ b/backend/app/db/models/articles.py @@ -0,0 +1,55 @@ +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, UniqueConstraint, func + +from app.db.engine import Base + + +class ArticleItem(Base): + __tablename__ = "article_items" + __table_args__ = ( + UniqueConstraint("platform", "article_id", name="uq_article_platform_article_id"), + UniqueConstraint("platform", "url_hash", name="uq_article_platform_url_hash"), + ) + + id = Column(Integer, primary_key=True, autoincrement=True) + platform = Column(String, nullable=False) + article_id = Column(String, nullable=False, default="") + url = Column(Text, nullable=False) + url_hash = Column(String, nullable=False) + title = Column(String, nullable=False) + author_name = Column(String, nullable=False, default="") + author_id = Column(String, nullable=False, default="") + summary_status = Column(String, nullable=False, default="pending") + task_id = Column(String, nullable=False, default="") + cover_url = Column(Text, nullable=False, default="") + published_at = Column(String, nullable=False, default="") + content_text = Column(Text, nullable=False, default="") + discovered_at = Column(DateTime, server_default=func.now()) + raw_metadata = Column(Text, nullable=False, default="{}") + + +class ArticleSubscription(Base): + __tablename__ = "article_subscriptions" + + id = Column(Integer, primary_key=True, autoincrement=True) + platform = Column(String, nullable=False) + type = Column(String, nullable=False) + query = Column(Text, nullable=False) + label = Column(String, nullable=False, default="") + enabled = Column(Boolean, nullable=False, default=True) + last_refresh_at = Column(DateTime, nullable=True) + last_error = Column(Text, nullable=False, default="") + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + + +class ArticleSubscriptionItem(Base): + __tablename__ = "article_subscription_items" + __table_args__ = ( + UniqueConstraint("subscription_id", "article_item_id", name="uq_subscription_article_item"), + ) + + id = Column(Integer, primary_key=True, autoincrement=True) + subscription_id = Column(Integer, ForeignKey("article_subscriptions.id"), nullable=False) + article_item_id = Column(Integer, ForeignKey("article_items.id"), nullable=False) + matched_at = Column(DateTime, server_default=func.now()) + match_reason = Column(Text, nullable=False, default="") diff --git a/backend/app/db/models/models.py b/backend/app/db/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..83cd50f45465174257622f92c926e024ce776d05 --- /dev/null +++ b/backend/app/db/models/models.py @@ -0,0 +1,12 @@ +from sqlalchemy import Column, Integer, String, DateTime, func, ForeignKey + +from app.db.engine import Base + + +class Model(Base): + __tablename__ = "models" + + id = Column(Integer, primary_key=True, autoincrement=True) + provider_id = Column(Integer, nullable=False) + model_name = Column(String, nullable=False) + created_at = Column(DateTime, server_default=func.now()) \ No newline at end of file diff --git a/backend/app/db/models/providers.py b/backend/app/db/models/providers.py new file mode 100644 index 0000000000000000000000000000000000000000..519d8bba4031babe793ddcdcda8cc51b0f9e4033 --- /dev/null +++ b/backend/app/db/models/providers.py @@ -0,0 +1,17 @@ +from sqlalchemy import Column, String, Integer, DateTime, func +from sqlalchemy.orm import declarative_base + +from app.db.engine import Base + + +class Provider(Base): + __tablename__ = "providers" + + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + logo = Column(String, nullable=False) + type = Column(String, nullable=False) + api_key = Column(String, nullable=False) + base_url = Column(String, nullable=False) + enabled = Column(Integer, default=1) + created_at = Column(DateTime, server_default=func.now()) \ No newline at end of file diff --git a/backend/app/db/models/trend_subscription.py b/backend/app/db/models/trend_subscription.py new file mode 100644 index 0000000000000000000000000000000000000000..43724434d73ac707bdc7a5610a388cb396c7e255 --- /dev/null +++ b/backend/app/db/models/trend_subscription.py @@ -0,0 +1,50 @@ +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func + +from app.db.engine import Base + + +class TrendSubscription(Base): + __tablename__ = "trend_subscriptions" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + keywords = Column(Text, nullable=False, default="[]") # JSON array of keyword strings + platforms = Column(Text, nullable=False, default='["all"]') # JSON array of platform ids + match_mode = Column(String, nullable=False, default="any") # "any" | "all" + enabled = Column(Boolean, nullable=False, default=True) + push_enabled = Column(Boolean, nullable=False, default=False) + push_channel_ids = Column(Text, nullable=False, default="[]") # JSON array of channel ids + last_matched_at = Column(DateTime, nullable=True) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + + +class TrendSubscriptionMatch(Base): + __tablename__ = "trend_subscription_matches" + + id = Column(Integer, primary_key=True, autoincrement=True) + subscription_id = Column(Integer, ForeignKey("trend_subscriptions.id"), nullable=False) + platform = Column(String, nullable=False) + item_id = Column(String, nullable=False) + title = Column(String, nullable=False) + url = Column(Text, nullable=False, default="") + hot_score = Column(String, nullable=False, default="") + matched_keywords = Column(Text, nullable=False, default="[]") # JSON array of matched keywords + matched_at = Column(DateTime, server_default=func.now()) + is_read = Column(Boolean, nullable=False, default=False) + # dedup: same subscription + same platform + same item_id + __table_args__ = ( + {"sqlite_autoincrement": True}, + ) + + +class NotificationChannel(Base): + __tablename__ = "notification_channels" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + type = Column(String, nullable=False) # "webhook" | "bark" | "email" + config = Column(Text, nullable=False, default="{}") # JSON object, type-specific + enabled = Column(Boolean, nullable=False, default=True) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) diff --git a/backend/app/db/models/video_tasks.py b/backend/app/db/models/video_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a1b86c58a3169ab97abeb642e71397845aca3a --- /dev/null +++ b/backend/app/db/models/video_tasks.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer, String, DateTime, func +from sqlalchemy.orm import declarative_base + +from app.db.engine import Base + + +class VideoTask(Base): + __tablename__ = "video_tasks" + + id = Column(Integer, primary_key=True, autoincrement=True) + video_id = Column(String, nullable=False) + platform = Column(String, nullable=False) + task_id = Column(String, unique=True, nullable=False) + created_at = Column(DateTime, server_default=func.now()) \ No newline at end of file diff --git a/backend/app/db/provider_dao.py b/backend/app/db/provider_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..4d2d875af9eacc1c25fb5a9d405ac2079429bcf4 --- /dev/null +++ b/backend/app/db/provider_dao.py @@ -0,0 +1,129 @@ +import json +import os +import sys +from app.db.models.providers import Provider +from app.utils.logger import get_logger +from app.db.engine import get_engine, Base, get_db + +logger = get_logger(__name__) + + +def get_builtin_providers_path(): + if getattr(sys, 'frozen', False): + base_path = sys._MEIPASS + else: + base_path = os.path.dirname(__file__) + return os.path.join(base_path, 'builtin_providers.json') + + +def seed_default_providers(): + db = next(get_db()) + try: + if db.query(Provider).count() > 0: + logger.info("Providers already exist, skipping seed.") + return + + json_path = get_builtin_providers_path() + try: + with open(json_path, 'r', encoding='utf-8') as f: + providers = json.load(f) + except Exception as e: + logger.error(f"Failed to read builtin_providers.json: {e}") + return + + for p in providers: + db.add(Provider( + id=p['id'], + name=p['name'], + api_key=p['api_key'], + base_url=p['base_url'], + logo=p['logo'], + type=p['type'], + enabled=p.get('enabled', 1) + )) + db.commit() + logger.info("Default providers seeded successfully.") + except Exception as e: + logger.error(f"Failed to seed default providers: {e}") + finally: + db.close() + + +def insert_provider(id: str, name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1): + db = next(get_db()) + try: + provider = Provider(id=id, name=name, api_key=api_key, base_url=base_url, logo=logo, type=type_, enabled=enabled) + db.add(provider) + db.commit() + logger.info(f"Provider inserted successfully. id: {id}, name: {name}, type: {type_}") + return id + except Exception as e: + logger.error(f"Failed to insert provider: {e}") + finally: + db.close() + + +def get_enabled_providers(): + db = next(get_db()) + try: + return db.query(Provider).filter_by(enabled=1).all() + finally: + db.close() + + +def get_provider_by_name(name: str): + db = next(get_db()) + try: + return db.query(Provider).filter_by(name=name).first() + finally: + db.close() + + +def get_provider_by_id(id: str): + db = next(get_db()) + try: + return db.query(Provider).filter_by(id=id).first() + finally: + db.close() + + +def get_all_providers(): + db = next(get_db()) + try: + return db.query(Provider).all() + finally: + db.close() + + +def update_provider(id: str, **kwargs): + db = next(get_db()) + try: + provider = db.query(Provider).filter_by(id=id).first() + if not provider: + logger.warning(f"Provider {id} not found for update.") + return + + for key, value in kwargs.items(): + if hasattr(provider, key): + setattr(provider, key, value) + + db.commit() + logger.info(f"Provider updated successfully. id: {id}, updated_fields: {list(kwargs.keys())}") + except Exception as e: + logger.error(f"Failed to update provider: {e}") + finally: + db.close() + + +def delete_provider(id: str): + db = next(get_db()) + try: + provider = db.query(Provider).filter_by(id=id).first() + if provider: + db.delete(provider) + db.commit() + logger.info(f"Provider deleted successfully. id: {id}") + except Exception as e: + logger.error(f"Failed to delete provider: {e}") + finally: + db.close() \ No newline at end of file diff --git a/backend/app/db/sqlite_client.py b/backend/app/db/sqlite_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a6546f4088e9127bd3bcef8319c60c67a1efedb1 --- /dev/null +++ b/backend/app/db/sqlite_client.py @@ -0,0 +1,4 @@ +import sqlite3 + +def get_connection(): + return sqlite3.connect("video_memo.db") diff --git a/backend/app/db/trend_subscription_dao.py b/backend/app/db/trend_subscription_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..f923466247b08bf1fed792a256ee53beec434d2f --- /dev/null +++ b/backend/app/db/trend_subscription_dao.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +from datetime import datetime + +from app.db.engine import get_db +from app.db.models.trend_subscription import ( + NotificationChannel, + TrendSubscription, + TrendSubscriptionMatch, +) + + +def _detach(obj): + data = {key: value for key, value in obj.__dict__.items() if not key.startswith("_")} + obj.__dict__.clear() + obj.__dict__.update(data) + return obj + + +# ─── Trend Subscriptions ────────────────────────────────────────────────────────── + +def create_subscription( + name: str, + keywords: list[str], + platforms: list[str] | None = None, + match_mode: str = "any", + push_enabled: bool = False, + push_channel_ids: list[int] | None = None, +) -> TrendSubscription: + db = next(get_db()) + try: + sub = TrendSubscription( + name=name, + keywords=json.dumps(keywords, ensure_ascii=False), + platforms=json.dumps(platforms or ["all"], ensure_ascii=False), + match_mode=match_mode, + push_enabled=push_enabled, + push_channel_ids=json.dumps(push_channel_ids or []), + ) + db.add(sub) + db.commit() + db.refresh(sub) + return _detach(sub) + finally: + db.close() + + +def list_subscriptions() -> list[TrendSubscription]: + db = next(get_db()) + try: + return [ + _detach(item) + for item in db.query(TrendSubscription).order_by(TrendSubscription.id.desc()).all() + ] + finally: + db.close() + + +def get_subscription(subscription_id: int) -> TrendSubscription | None: + db = next(get_db()) + try: + item = db.query(TrendSubscription).filter_by(id=subscription_id).first() + return _detach(item) if item else None + finally: + db.close() + + +def update_subscription( + subscription_id: int, + name: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + match_mode: str | None = None, + enabled: bool | None = None, + push_enabled: bool | None = None, + push_channel_ids: list[int] | None = None, +) -> TrendSubscription | None: + db = next(get_db()) + try: + sub = db.query(TrendSubscription).filter_by(id=subscription_id).first() + if sub is None: + return None + if name is not None: + sub.name = name + if keywords is not None: + sub.keywords = json.dumps(keywords, ensure_ascii=False) + if platforms is not None: + sub.platforms = json.dumps(platforms, ensure_ascii=False) + if match_mode is not None: + sub.match_mode = match_mode + if enabled is not None: + sub.enabled = enabled + if push_enabled is not None: + sub.push_enabled = push_enabled + if push_channel_ids is not None: + sub.push_channel_ids = json.dumps(push_channel_ids) + db.commit() + db.refresh(sub) + return _detach(sub) + finally: + db.close() + + +def delete_subscription(subscription_id: int) -> bool: + db = next(get_db()) + try: + sub = db.query(TrendSubscription).filter_by(id=subscription_id).first() + if sub is None: + return False + # also delete associated matches + db.query(TrendSubscriptionMatch).filter_by(subscription_id=subscription_id).delete() + db.delete(sub) + db.commit() + return True + finally: + db.close() + + +def update_subscription_refresh(subscription_id: int) -> None: + db = next(get_db()) + try: + sub = db.query(TrendSubscription).filter_by(id=subscription_id).first() + if sub: + sub.last_matched_at = datetime.now() + db.commit() + finally: + db.close() + + +# ─── Trend Subscription Matches ─────────────────────────────────────────────────── + +def create_match( + subscription_id: int, + platform: str, + item_id: str, + title: str, + url: str = "", + hot_score: str = "", + matched_keywords: list[str] | None = None, +) -> TrendSubscriptionMatch | None: + """Create a match record. Returns None if this (subscription, platform, item_id) already exists.""" + db = next(get_db()) + try: + existing = ( + db.query(TrendSubscriptionMatch) + .filter_by(subscription_id=subscription_id, platform=platform, item_id=item_id) + .first() + ) + if existing is not None: + return None # already matched before + match = TrendSubscriptionMatch( + subscription_id=subscription_id, + platform=platform, + item_id=item_id, + title=title, + url=url, + hot_score=hot_score, + matched_keywords=json.dumps(matched_keywords or [], ensure_ascii=False), + ) + db.add(match) + db.commit() + db.refresh(match) + return _detach(match) + finally: + db.close() + + +def list_matches( + subscription_id: int | None = None, + limit: int = 100, + unread_only: bool = False, +) -> list[TrendSubscriptionMatch]: + db = next(get_db()) + try: + query = db.query(TrendSubscriptionMatch) + if subscription_id is not None: + query = query.filter_by(subscription_id=subscription_id) + if unread_only: + query = query.filter_by(is_read=False) + return [ + _detach(item) + for item in query.order_by(TrendSubscriptionMatch.matched_at.desc()) + .limit(limit) + .all() + ] + finally: + db.close() + + +def mark_matches_read(subscription_id: int) -> int: + """Mark all matches for a subscription as read. Returns count of updated rows.""" + db = next(get_db()) + try: + count = ( + db.query(TrendSubscriptionMatch) + .filter_by(subscription_id=subscription_id, is_read=False) + .update({"is_read": True}) + ) + db.commit() + return count + finally: + db.close() + + +def count_unread_matches(subscription_id: int) -> int: + db = next(get_db()) + try: + return ( + db.query(TrendSubscriptionMatch) + .filter_by(subscription_id=subscription_id, is_read=False) + .count() + ) + finally: + db.close() + + +# ─── Notification Channels ──────────────────────────────────────────────────────── + +def create_channel(name: str, channel_type: str, config: dict | None = None) -> NotificationChannel: + db = next(get_db()) + try: + channel = NotificationChannel( + name=name, + type=channel_type, + config=json.dumps(config or {}, ensure_ascii=False), + ) + db.add(channel) + db.commit() + db.refresh(channel) + return _detach(channel) + finally: + db.close() + + +def list_channels() -> list[NotificationChannel]: + db = next(get_db()) + try: + return [ + _detach(item) + for item in db.query(NotificationChannel).order_by(NotificationChannel.id.desc()).all() + ] + finally: + db.close() + + +def get_channel(channel_id: int) -> NotificationChannel | None: + db = next(get_db()) + try: + item = db.query(NotificationChannel).filter_by(id=channel_id).first() + return _detach(item) if item else None + finally: + db.close() + + +def update_channel( + channel_id: int, + name: str | None = None, + channel_type: str | None = None, + config: dict | None = None, + enabled: bool | None = None, +) -> NotificationChannel | None: + db = next(get_db()) + try: + channel = db.query(NotificationChannel).filter_by(id=channel_id).first() + if channel is None: + return None + if name is not None: + channel.name = name + if channel_type is not None: + channel.type = channel_type + if config is not None: + channel.config = json.dumps(config, ensure_ascii=False) + if enabled is not None: + channel.enabled = enabled + db.commit() + db.refresh(channel) + return _detach(channel) + finally: + db.close() + + +def delete_channel(channel_id: int) -> bool: + db = next(get_db()) + try: + channel = db.query(NotificationChannel).filter_by(id=channel_id).first() + if channel is None: + return False + db.delete(channel) + db.commit() + return True + finally: + db.close() diff --git a/backend/app/db/video_task_dao.py b/backend/app/db/video_task_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..42cd4359268745870f7d880e793bded4bdfdd43a --- /dev/null +++ b/backend/app/db/video_task_dao.py @@ -0,0 +1,61 @@ +from app.db.models.video_tasks import VideoTask +from app.db.engine import get_db +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +# 插入任务 +def insert_video_task(video_id: str, platform: str, task_id: str): + db = next(get_db()) + try: + task = VideoTask(video_id=video_id, platform=platform, task_id=task_id) + db.add(task) + db.commit() + db.refresh(task) + logger.info(f"Video task inserted successfully. video_id: {video_id}, platform: {platform}, task_id: {task_id}") + except Exception as e: + logger.error(f"Failed to insert video task: {e}") + finally: + db.close() + + +# 查询任务(最新一条) +def get_task_by_video(video_id: str, platform: str): + db = next(get_db()) + try: + task = ( + db.query(VideoTask) + .filter_by(video_id=video_id, platform=platform) + .order_by(VideoTask.created_at.desc()) + .first() + ) + if task: + logger.info(f"Task found for video_id: {video_id} and platform: {platform}") + return task.task_id + else: + logger.info(f"No task found for video_id: {video_id} and platform: {platform}") + return None + except Exception as e: + logger.error(f"Failed to get task by video: {e}") + finally: + db.close() + + +# 删除任务 +def delete_task_by_video(video_id: str, platform: str): + db = next(get_db()) + try: + tasks = ( + db.query(VideoTask) + .filter_by(video_id=video_id, platform=platform) + .all() + ) + for task in tasks: + db.delete(task) + db.commit() + logger.info(f"Task(s) deleted for video_id: {video_id} and platform: {platform}") + except Exception as e: + logger.error(f"Failed to delete task by video: {e}") + finally: + db.close() \ No newline at end of file diff --git a/backend/app/decorators/__init__.py b/backend/app/decorators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/decorators/timeit.py b/backend/app/decorators/timeit.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ede42c8df91345ac139b67a97af3bfacc046ef --- /dev/null +++ b/backend/app/decorators/timeit.py @@ -0,0 +1,13 @@ +import time +import functools + +def timeit(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = func(*args, **kwargs) + end = time.perf_counter() + duration = end - start + print(f"{func.__name__} executed in {duration:.4f} seconds") + return result + return wrapper diff --git a/backend/app/downloaders/__init__.py b/backend/app/downloaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/downloaders/base.py b/backend/app/downloaders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e71ad0cf058afa38824afa617b1bf0a401de65 --- /dev/null +++ b/backend/app/downloaders/base.py @@ -0,0 +1,52 @@ +import enum + +from abc import ABC, abstractmethod +from typing import Optional, Union + +from app.enmus.note_enums import DownloadQuality +from app.models.notes_model import AudioDownloadResult +from app.models.transcriber_model import TranscriptResult +from os import getenv +QUALITY_MAP = { + "fast": "32", + "medium": "64", + "slow": "128" +} + + +class Downloader(ABC): + def __init__(self): + #TODO 需要修改为可配置 + self.quality = QUALITY_MAP.get('fast') + self.cache_data=getenv('DATA_DIR') + + @abstractmethod + def download(self, video_url: str, output_dir: str = None, + quality: DownloadQuality = "fast", need_video: Optional[bool] = False, + skip_download: bool = False) -> AudioDownloadResult: + ''' + + :param need_video: + :param video_url: 资源链接 + :param output_dir: 输出路径 默认根目录data + :param quality: 音频质量 fast | medium | slow + :return:返回一个 AudioDownloadResult 类 + ''' + pass + + @staticmethod + def download_video(self, video_url: str, + output_dir: Union[str, None] = None) -> str: + pass + + def download_subtitles(self, video_url: str, output_dir: str = None, + langs: list = None) -> Optional[TranscriptResult]: + ''' + 尝试获取平台字幕(人工字幕或自动生成字幕) + + :param video_url: 视频链接 + :param output_dir: 输出路径 + :param langs: 优先语言列表,如 ['zh-Hans', 'zh', 'en'] + :return: TranscriptResult 或 None(无字幕时) + ''' + return None diff --git a/backend/app/downloaders/bilibili_downloader.py b/backend/app/downloaders/bilibili_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..0a94849f69478929ba14b8e892ff8b9636919a8b --- /dev/null +++ b/backend/app/downloaders/bilibili_downloader.py @@ -0,0 +1,343 @@ +import os +import json +import logging +import tempfile +from abc import ABC +from typing import Union, Optional, List + +import yt_dlp + +from app.downloaders.base import Downloader, DownloadQuality, QUALITY_MAP +from app.downloaders.bilibili_subtitle import BilibiliSubtitleFetcher +from app.models.notes_model import AudioDownloadResult +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.utils.path_helper import get_data_dir +from app.utils.url_parser import extract_video_id +from app.services.cookie_manager import CookieConfigManager + +logger = logging.getLogger(__name__) + + +class BilibiliDownloader(Downloader, ABC): + def __init__(self): + super().__init__() + self._cookie_mgr = CookieConfigManager() + self._cookie = self._cookie_mgr.get('bilibili') + self._cookiefile = self._write_netscape_cookie_file() + + def _write_netscape_cookie_file(self) -> Optional[str]: + """将 Cookie 写入 Netscape 格式临时文件,返回文件路径(供 yt-dlp cookiefile 使用)""" + if not self._cookie: + logger.warning("B站 Cookie 未配置,下载可能失败") + return None + lines = ["# Netscape HTTP Cookie File\n"] + for pair in self._cookie.split("; "): + if "=" in pair: + key, value = pair.split("=", 1) + lines.append(f".bilibili.com\tTRUE\t/\tFALSE\t0\t{key}\t{value}\n") + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') + tmp.writelines(lines) + tmp.close() + logger.info("已生成 B站 Netscape Cookie 文件: %s (条目: %d)", tmp.name, len(lines) - 1) + return tmp.name + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video:Optional[bool]=False + ) -> AudioDownloadResult: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir=self.cache_data + os.makedirs(output_dir, exist_ok=True) + + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + + ydl_opts = { + 'format': 'bestaudio[ext=m4a]/bestaudio/best', + 'outtmpl': output_path, + 'http_headers': {'Referer': 'https://www.bilibili.com'}, + 'postprocessors': [ + { + 'key': 'FFmpegExtractAudio', + 'preferredcodec': 'mp3', + 'preferredquality': '64', + } + ], + 'noplaylist': True, + 'quiet': False, + } + if self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + video_id = info.get("id") + title = info.get("title") + duration = info.get("duration", 0) + cover_url = info.get("thumbnail") + audio_path = os.path.join(output_dir, f"{video_id}.mp3") + + return AudioDownloadResult( + file_path=audio_path, + title=title, + duration=duration, + cover_url=cover_url, + platform="bilibili", + video_id=video_id, + raw_info=info, + video_path=None # ❗音频下载不包含视频路径 + ) + + def download_video( + self, + video_url: str, + output_dir: Union[str, None] = None, + ) -> str: + """ + 下载视频,返回视频文件路径 + """ + + if output_dir is None: + output_dir = get_data_dir() + os.makedirs(output_dir, exist_ok=True) + print("video_url",video_url) + video_id=extract_video_id(video_url, "bilibili") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + if os.path.exists(video_path): + return video_path + + # 检查是否已经存在 + + + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + + ydl_opts = { + 'format': 'bv*[ext=mp4]/bestvideo+bestaudio/best', + 'outtmpl': output_path, + 'http_headers': {'Referer': 'https://www.bilibili.com'}, + 'noplaylist': True, + 'quiet': False, + 'merge_output_format': 'mp4', # 确保合并成 mp4 + } + if self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + video_id = info.get("id") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + + if not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件未找到: {video_path}") + + return video_path + + def delete_video(self, video_path: str) -> str: + """ + 删除视频文件 + """ + if os.path.exists(video_path): + os.remove(video_path) + return f"视频文件已删除: {video_path}" + else: + return f"视频文件未找到: {video_path}" + + def download_subtitles(self, video_url: str, output_dir: str = None, + langs: List[str] = None) -> Optional[TranscriptResult]: + """ + 尝试获取B站视频字幕 + + :param video_url: 视频链接 + :param output_dir: 输出路径 + :param langs: 优先语言列表 + :return: TranscriptResult 或 None + """ + # 1) 优先走 B 站官方 player API(直拉,无需下视频;AI 字幕需 SESSDATA cookie) + try: + result = BilibiliSubtitleFetcher().fetch_subtitles(video_url) + if result and result.segments: + return result + except Exception as e: + logger.warning(f"player API 直拉字幕异常,回退到 yt-dlp: {e}") + + # 2) Fallback:原 yt-dlp 路径(更脆弱,遇到签名/Cookie 问题失败概率较高) + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + if langs is None: + langs = ['zh-Hans', 'zh', 'zh-CN', 'ai-zh', 'en', 'en-US'] + + video_id = extract_video_id(video_url, "bilibili") + + ydl_opts = { + 'writesubtitles': True, + 'writeautomaticsub': True, + 'subtitleslangs': langs, + 'subtitlesformat': 'srt/json3/best', # 支持多种格式 + 'skip_download': True, + 'outtmpl': os.path.join(output_dir, f'{video_id}.%(ext)s'), + 'quiet': True, + } + + # 通过 CookieConfigManager 注入 B站 Cookie(Netscape cookiefile) + if self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + ydl_opts['http_headers'] = {'Referer': 'https://www.bilibili.com'} + + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + + # 查找下载的字幕文件 + subtitles = info.get('requested_subtitles') or {} + if not subtitles: + logger.info(f"B站视频 {video_id} 没有可用字幕") + return None + + # 按优先级查找字幕 + detected_lang = None + sub_info = None + for lang in langs: + if lang in subtitles: + detected_lang = lang + sub_info = subtitles[lang] + break + + # 如果按优先级没找到,取第一个可用的(排除弹幕) + if not detected_lang: + for lang, info_item in subtitles.items(): + if lang != 'danmaku': # 排除弹幕 + detected_lang = lang + sub_info = info_item + break + + if not sub_info: + logger.info(f"B站视频 {video_id} 没有可用字幕(排除弹幕)") + return None + + # 检查是否有内嵌数据(yt-dlp 有时直接返回字幕内容) + if 'data' in sub_info and sub_info['data']: + logger.info(f"直接从返回数据解析字幕: {detected_lang}") + return self._parse_srt_content(sub_info['data'], detected_lang) + + # 查找字幕文件 + ext = sub_info.get('ext', 'srt') + subtitle_file = os.path.join(output_dir, f"{video_id}.{detected_lang}.{ext}") + + if not os.path.exists(subtitle_file): + logger.info(f"字幕文件不存在: {subtitle_file}") + return None + + # 根据格式解析字幕文件 + if ext == 'json3': + return self._parse_json3_subtitle(subtitle_file, detected_lang) + else: + with open(subtitle_file, 'r', encoding='utf-8') as f: + return self._parse_srt_content(f.read(), detected_lang) + + except Exception as e: + logger.warning(f"获取B站字幕失败: {e}") + return None + + def _parse_srt_content(self, srt_content: str, language: str) -> Optional[TranscriptResult]: + """ + 解析 SRT 格式字幕内容 + + :param srt_content: SRT 字幕文本内容 + :param language: 语言代码 + :return: TranscriptResult + """ + import re + try: + segments = [] + # SRT 格式: 序号\n时间戳\n文本\n\n + pattern = r'(\d+)\n(\d{2}:\d{2}:\d{2},\d{3})\s*-->\s*(\d{2}:\d{2}:\d{2},\d{3})\n(.*?)(?=\n\n|\n\d+\n|$)' + matches = re.findall(pattern, srt_content, re.DOTALL) + + for match in matches: + idx, start_time, end_time, text = match + text = text.strip() + if not text: + continue + + # 转换时间格式 00:00:00,000 -> 秒 + def time_to_seconds(t): + parts = t.replace(',', '.').split(':') + return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2]) + + segments.append(TranscriptSegment( + start=time_to_seconds(start_time), + end=time_to_seconds(end_time), + text=text + )) + + if not segments: + return None + + full_text = ' '.join(seg.text for seg in segments) + logger.info(f"成功解析B站SRT字幕,共 {len(segments)} 段") + return TranscriptResult( + language=language, + full_text=full_text, + segments=segments, + raw={'source': 'bilibili_subtitle', 'format': 'srt'} + ) + + except Exception as e: + logger.warning(f"解析SRT字幕失败: {e}") + return None + + def _parse_json3_subtitle(self, subtitle_file: str, language: str) -> Optional[TranscriptResult]: + """ + 解析 json3 格式字幕文件 + + :param subtitle_file: 字幕文件路径 + :param language: 语言代码 + :return: TranscriptResult + """ + try: + with open(subtitle_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + segments = [] + events = data.get('events', []) + + for event in events: + # json3 格式中时间单位是毫秒 + start_ms = event.get('tStartMs', 0) + duration_ms = event.get('dDurationMs', 0) + + # 提取文本 + segs = event.get('segs', []) + text = ''.join(seg.get('utf8', '') for seg in segs).strip() + + if text: # 只添加非空文本 + segments.append(TranscriptSegment( + start=start_ms / 1000.0, + end=(start_ms + duration_ms) / 1000.0, + text=text + )) + + if not segments: + return None + + full_text = ' '.join(seg.text for seg in segments) + + logger.info(f"成功解析B站字幕,共 {len(segments)} 段") + return TranscriptResult( + language=language, + full_text=full_text, + segments=segments, + raw={'source': 'bilibili_subtitle', 'file': subtitle_file} + ) + + except Exception as e: + logger.warning(f"解析字幕文件失败: {e}") + return None \ No newline at end of file diff --git a/backend/app/downloaders/bilibili_subtitle.py b/backend/app/downloaders/bilibili_subtitle.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3790e155b617b9edc86bf4568935a63f067c58 --- /dev/null +++ b/backend/app/downloaders/bilibili_subtitle.py @@ -0,0 +1,164 @@ +""" +直接调用 B 站 player API 拿字幕,绕过 yt-dlp。 + +流程: +1. 从 URL 提 BV id(已有 utils.url_parser.extract_video_id) +2. GET /x/web-interface/view?bvid=BVxxx → 拿 cid +3. GET /x/player/wbi/v2?bvid=...&cid=... → 返回 data.subtitle.subtitles[] + 每条带 subtitle_url(B 站后端已经签好 auth_key 的完整地址) +4. 按优先级(人工 zh-CN > AI zh-CN > 任意 zh > 任意非空)选一条 +5. fetch subtitle_url → JSON {body:[{from,to,content,...}]} +6. 解析为 TranscriptResult + +AI 字幕需要登录态 cookie(SESSDATA);通过 CookieConfigManager 注入。 +""" + +from typing import List, Optional + +import requests + +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.services.cookie_manager import CookieConfigManager +from app.utils.logger import get_logger +from app.utils.url_parser import extract_video_id + +logger = get_logger(__name__) + +UA = ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36" +) + + +class BilibiliSubtitleFetcher: + """通过 B 站官方 API 直拉字幕。""" + + def __init__(self): + self._cookie = CookieConfigManager().get("bilibili") or "" + + def _headers(self) -> dict: + h = { + "User-Agent": UA, + "Referer": "https://www.bilibili.com", + } + if self._cookie: + h["Cookie"] = self._cookie + return h + + def _get_cid(self, bvid: str) -> Optional[int]: + url = "https://api.bilibili.com/x/web-interface/view" + try: + resp = requests.get(url, params={"bvid": bvid}, headers=self._headers(), timeout=10) + data = resp.json() + except Exception as e: + logger.warning(f"获取 cid 失败: {e}") + return None + if data.get("code") != 0: + logger.warning(f"view API 返回错误: code={data.get('code')}, msg={data.get('message')}") + return None + cid = data.get("data", {}).get("cid") + return int(cid) if cid else None + + def _list_subtitles(self, bvid: str, cid: int) -> List[dict]: + url = "https://api.bilibili.com/x/player/wbi/v2" + try: + resp = requests.get(url, params={"bvid": bvid, "cid": cid}, headers=self._headers(), timeout=10) + data = resp.json() + except Exception as e: + logger.warning(f"获取字幕列表失败: {e}") + return [] + if data.get("code") != 0: + logger.warning(f"player API 返回错误: code={data.get('code')}, msg={data.get('message')}") + return [] + subtitles = data.get("data", {}).get("subtitle", {}).get("subtitles", []) + return subtitles or [] + + def _pick(self, subtitles: List[dict]) -> Optional[dict]: + """优先级:人工中文 > AI 中文 > 任意中文 > 任意非空。""" + if not subtitles: + return None + + def is_zh(s: dict) -> bool: + lan = (s.get("lan") or "").lower() + return lan.startswith("zh") or lan == "ai-zh" + + # 人工中文(type 0=AI, 1=人工 ;ai_type=0 视为人工) + for s in subtitles: + if is_zh(s) and not s.get("ai_type"): + return s + # AI 中文 + for s in subtitles: + if is_zh(s): + return s + # 任意非空 + return subtitles[0] + + @staticmethod + def _normalize_url(url: str) -> str: + if url.startswith("//"): + return "https:" + url + return url + + def _fetch_body(self, subtitle_url: str) -> Optional[List[dict]]: + try: + resp = requests.get(self._normalize_url(subtitle_url), headers=self._headers(), timeout=15) + data = resp.json() + return data.get("body") or [] + except Exception as e: + logger.warning(f"下载字幕 JSON 失败: {e}") + return None + + def fetch_subtitles(self, video_url: str) -> Optional[TranscriptResult]: + bvid = extract_video_id(video_url, "bilibili") + if not bvid: + logger.info("无法从 URL 提取 BV id") + return None + + cid = self._get_cid(bvid) + if not cid: + logger.info(f"{bvid} 没有取到 cid") + return None + + subtitles = self._list_subtitles(bvid, cid) + if not subtitles: + logger.info(f"{bvid} (cid={cid}) 没有可用字幕轨") + return None + + track = self._pick(subtitles) + if not track or not track.get("subtitle_url"): + logger.info(f"{bvid} 字幕轨存在但没有 subtitle_url(可能未登录、需要 SESSDATA cookie)") + return None + + lan = track.get("lan") or "zh" + body = self._fetch_body(track["subtitle_url"]) + if not body: + return None + + segments: List[TranscriptSegment] = [] + for item in body: + text = (item.get("content") or "").strip() + if not text: + continue + segments.append(TranscriptSegment( + start=float(item.get("from", 0)), + end=float(item.get("to", 0)), + text=text, + )) + + if not segments: + return None + + full_text = " ".join(s.text for s in segments) + logger.info(f"B站直拉字幕成功: {bvid} lan={lan} 共 {len(segments)} 段") + return TranscriptResult( + language=lan, + full_text=full_text, + segments=segments, + raw={ + "source": "bilibili_player_api", + "bvid": bvid, + "cid": cid, + "lan": lan, + "ai_type": track.get("ai_type"), + }, + ) diff --git a/backend/app/downloaders/common.py b/backend/app/downloaders/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c71342f392c82126bbad96581fea404a5945324d --- /dev/null +++ b/backend/app/downloaders/common.py @@ -0,0 +1 @@ +# def download(): diff --git a/backend/app/downloaders/douyin_downloader.py b/backend/app/downloaders/douyin_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5f31719984f65c815379ae17af9cd890c6a7a7f4 --- /dev/null +++ b/backend/app/downloaders/douyin_downloader.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import json +import os +import re +import subprocess +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, Union +from urllib.parse import parse_qs, unquote, urlparse + +import requests + +from app.downloaders.base import Downloader +from app.enmus.note_enums import DownloadQuality +from app.models.audio_model import AudioDownloadResult +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.utils.path_helper import get_data_dir + + +SHARE_PAGE_UA = ( + "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) " + "AppleWebKit/605.1.15 (KHTML, like Gecko) " + "Version/17.0 Mobile/15E148 Safari/604.1" +) + +ROUTER_DATA_RE = re.compile(r"window\._ROUTER_DATA\s*=\s*(\{.+)", re.DOTALL) +RENDER_DATA_RE = re.compile( + r'' +) +DOUYIN_URL_RE = re.compile( + r"https?://(?:v\.douyin\.com|www\.douyin\.com|www\.iesdouyin\.com|m\.douyin\.com)[^\s\]]*" +) +IMAGE_AWEME_TYPES = {2, 68} + + +class DouyinResolveError(Exception): + pass + + +@dataclass +class DouyinContentMeta: + aweme_id: str + title: str + author: str + source_url: str + content_type: Literal["video", "image"] = "video" + aweme_type: Optional[int] = None + download_url: str = "" + cover_url: Optional[str] = None + image_urls: list[str] = field(default_factory=list) + duration: float = 0 + tags: list[str] = field(default_factory=list) + + +def _session() -> requests.Session: + session = requests.Session() + session.headers.update( + { + "User-Agent": SHARE_PAGE_UA, + "Accept-Language": "zh-CN,zh;q=0.9", + "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", + } + ) + return session + + +def expand_share_url(share_text: str) -> str: + """从抖音分享文案中提取可访问链接。""" + match = DOUYIN_URL_RE.search((share_text or "").strip()) + if not match: + raise DouyinResolveError("未在输入中找到抖音链接") + return match.group(0).rstrip("/.,;)") + + +def _extract_aweme_id_from_search_url(url: str) -> Optional[str]: + parsed = urlparse(url) + if not parsed.netloc.endswith("douyin.com") or not parsed.path.startswith("/search"): + return None + + params = parse_qs(parsed.query) + for key in ("modal_id", "item_ids"): + for value in params.get(key, []): + match = re.search(r"\d{10,}", value) + if match: + return match.group(0) + return None + + +def normalize_to_share_page(url: str) -> str: + """www.douyin.com 的 video/note 页面转为移动端分享页。""" + note = re.search(r"https?://(?:www\.)?douyin\.com/note/(\d+)", url) + if note: + return f"https://www.iesdouyin.com/share/note/{note.group(1)}/" + video = re.search(r"https?://(?:www\.)?douyin\.com/video/(\d+)", url) + if video: + return f"https://www.iesdouyin.com/share/video/{video.group(1)}/" + search_aweme_id = _extract_aweme_id_from_search_url(url) + if search_aweme_id: + return f"https://www.iesdouyin.com/share/video/{search_aweme_id}/" + return url + + +def resolve_share_page(session: requests.Session, share_url: str) -> tuple[str, str]: + response = session.get(share_url, allow_redirects=True, timeout=30) + response.raise_for_status() + return str(response.url), response.text + + +def extract_aweme_id(page_url: str, html: Optional[str] = None) -> str: + patterns = [ + r"/video/(\d+)", + r"/note/(\d+)", + r"/share/video/(\d+)", + r"/share/note/(\d+)", + r"modal_id=(\d+)", + r"item_ids=(\d+)", + r'"aweme_id"\s*:\s*"?(\d+)"?', + r'"itemId"\s*:\s*"?(\d+)"?', + ] + for pattern in patterns: + match = re.search(pattern, page_url) + if match: + return match.group(1) + if html: + for pattern in patterns: + match = re.search(pattern, html) + if match: + return match.group(1) + raise DouyinResolveError(f"无法从分享页解析作品 ID: {page_url}") + + +def _parse_router_data(html: str) -> Optional[dict[str, Any]]: + match = ROUTER_DATA_RE.search(html) + if not match: + return None + raw = match.group(1).split("")[0].rstrip().rstrip(";") + try: + return json.loads(raw) + except json.JSONDecodeError: + return None + + +def _parse_render_data(html: str) -> Optional[dict[str, Any]]: + match = RENDER_DATA_RE.search(html) + if not match: + return None + try: + return json.loads(unquote(match.group(1))) + except json.JSONDecodeError: + return None + + +def _find_item_list(obj: Any) -> list[dict[str, Any]]: + if isinstance(obj, dict): + item_list = obj.get("item_list") + if isinstance(item_list, list) and item_list: + first = item_list[0] + if isinstance(first, dict) and ( + "aweme_id" in first or "awemeId" in first or "video" in first or "images" in first + ): + return item_list + for value in obj.values(): + found = _find_item_list(value) + if found: + return found + elif isinstance(obj, list): + for item in obj: + found = _find_item_list(item) + if found: + return found + return [] + + +def _pick_url_from_image_node(image: dict[str, Any]) -> Optional[str]: + url_list = image.get("url_list") or [] + if url_list: + return str(url_list[-1]) + download_list = image.get("download_url_list") or [] + if download_list: + return str(download_list[-1]) + return None + + +def _extract_image_urls(item: dict[str, Any]) -> list[str]: + urls: list[str] = [] + seen: set[str] = set() + + def add(url: Optional[str]) -> None: + if url and url not in seen: + seen.add(url) + urls.append(url) + + for image in item.get("images") or []: + if isinstance(image, dict): + add(_pick_url_from_image_node(image)) + + post = item.get("image_post_info") or {} + if isinstance(post, dict): + for image in post.get("images") or []: + if isinstance(image, dict): + add(_pick_url_from_image_node(image)) + + return urls + + +def _has_playable_video(item: dict[str, Any]) -> bool: + video = item.get("video") or {} + if not isinstance(video, dict): + return False + play_addr = video.get("play_addr") or video.get("playAddr") or {} + if not isinstance(play_addr, dict): + return False + return bool(play_addr.get("uri") or play_addr.get("url_list")) + + +def _is_image_note(item: dict[str, Any]) -> bool: + aweme_type = item.get("aweme_type") + if aweme_type in IMAGE_AWEME_TYPES: + return True + return bool(_extract_image_urls(item)) and not _has_playable_video(item) + + +def _build_no_watermark_url(play_addr: dict[str, Any]) -> str: + uri = play_addr.get("uri") or "" + url_list = play_addr.get("url_list") or [] + if uri: + return f"https://aweme.snssdk.com/aweme/v1/play/?video_id={uri}&ratio=720p&line=0" + if url_list: + return str(url_list[0]).replace("playwm", "play") + raise DouyinResolveError("分享页内嵌数据中未找到视频播放地址") + + +def _extract_tags(item: dict[str, Any]) -> list[str]: + tags: list[str] = [] + seen: set[str] = set() + for tag in item.get("text_extra") or item.get("video_tag") or []: + if not isinstance(tag, dict): + continue + name = tag.get("hashtag_name") or tag.get("tag_name") or tag.get("name") + if name and name not in seen: + seen.add(name) + tags.append(str(name)) + return tags + + +def _duration_seconds(raw: Any) -> float: + try: + value = float(raw or 0) + except (TypeError, ValueError): + return 0 + return value / 1000 if value > 10000 else value + + +def _meta_from_aweme_item(item: dict[str, Any], source_url: str) -> DouyinContentMeta: + aweme_id = str(item.get("aweme_id") or item.get("awemeId") or "") + title = (item.get("desc") or item.get("caption") or "").strip() or f"douyin_{aweme_id}" + aweme_type = item.get("aweme_type") + tags = _extract_tags(item) + + author = "" + author_info = item.get("author") or {} + if isinstance(author_info, dict): + author = author_info.get("nickname") or author_info.get("unique_id") or "" + + duration = _duration_seconds(item.get("duration")) + + if _is_image_note(item): + image_urls = _extract_image_urls(item) + if not image_urls: + raise DouyinResolveError("识别为图文,但未找到图片地址") + return DouyinContentMeta( + aweme_id=aweme_id, + title=title, + author=author, + source_url=source_url, + content_type="image", + aweme_type=aweme_type, + cover_url=image_urls[0], + image_urls=image_urls, + duration=duration, + tags=tags, + ) + + video = item.get("video") or {} + if not isinstance(video, dict): + raise DouyinResolveError("分享页内嵌数据中未找到视频节点") + play_addr = video.get("play_addr") or video.get("playAddr") or {} + if not isinstance(play_addr, dict): + raise DouyinResolveError("视频节点缺少 play_addr") + + download_url = _build_no_watermark_url(play_addr) + cover_url = None + for key in ("cover", "origin_cover", "dynamic_cover", "cover_original_scale"): + cover_info = video.get(key) or {} + if isinstance(cover_info, dict): + covers = cover_info.get("url_list") or [] + if covers: + cover_url = str(covers[0]) + break + + for bit_rate in video.get("bit_rate") or []: + if not isinstance(bit_rate, dict): + continue + bit_play = bit_rate.get("play_addr") or {} + if isinstance(bit_play, dict) and bit_play.get("url_list"): + candidate = str(bit_play["url_list"][0]) + if "playwm" not in candidate and ("douyinvod" in candidate or "bytecdn" in candidate): + download_url = candidate + break + + return DouyinContentMeta( + aweme_id=aweme_id, + title=title, + author=author, + source_url=source_url, + content_type="video", + aweme_type=aweme_type, + download_url=download_url, + cover_url=cover_url, + duration=duration, + tags=tags, + ) + + +def parse_share_page_html(html: str, page_url: str, original_share: str) -> DouyinContentMeta: + for parser in (_parse_router_data, _parse_render_data): + payload = parser(html) + if not payload: + continue + items = _find_item_list(payload) + if items: + meta = _meta_from_aweme_item(items[0], original_share) + if meta.aweme_id: + return meta + return DouyinContentMeta( + aweme_id=extract_aweme_id(page_url, html), + title=meta.title, + author=meta.author, + source_url=meta.source_url, + content_type=meta.content_type, + aweme_type=meta.aweme_type, + download_url=meta.download_url, + cover_url=meta.cover_url, + image_urls=meta.image_urls, + duration=meta.duration, + tags=meta.tags, + ) + + raise DouyinResolveError( + "分享页未找到内嵌公开数据(_ROUTER_DATA / RENDER_DATA)。" + "请确认链接有效。" + ) + + +def resolve_douyin_share(share_text: str) -> DouyinContentMeta: + session = _session() + share_url = expand_share_url(share_text) + fetch_url = normalize_to_share_page(share_url) + page_url, html = resolve_share_page(session, fetch_url) + return parse_share_page_html(html, page_url, share_url) + + +def _download_file(url: str, dest: str) -> str: + os.makedirs(os.path.dirname(dest), exist_ok=True) + headers = {"User-Agent": SHARE_PAGE_UA, "Referer": "https://www.iesdouyin.com/"} + with requests.get(url, headers=headers, stream=True, timeout=120) as response: + response.raise_for_status() + with open(dest, "wb") as file: + for chunk in response.iter_content(chunk_size=1024 * 256): + if chunk: + file.write(chunk) + return dest + + +def _extract_audio(video_path: str, audio_path: str) -> None: + subprocess.run( + ["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "libmp3lame", audio_path], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + +def _build_result( + meta: DouyinContentMeta, + audio_path: str, + video_path: Optional[str], +) -> AudioDownloadResult: + return AudioDownloadResult( + file_path=audio_path, + title=meta.title, + duration=meta.duration, + cover_url=meta.cover_url, + platform="douyin", + video_id=meta.aweme_id, + raw_info={ + "tags": meta.tags, + "author": meta.author, + "source_url": meta.source_url, + "content_type": meta.content_type, + "image_urls": meta.image_urls, + }, + video_path=video_path, + ) + + +class DouyinDownloader(Downloader): + def __init__(self, cookie=None): + super().__init__() + + def extract_video_id(self, url: str) -> str: + try: + return extract_aweme_id(normalize_to_share_page(expand_share_url(url))) + except DouyinResolveError: + return "" + + def _resolve_meta(self, video_url: str) -> DouyinContentMeta: + try: + return resolve_douyin_share(video_url) + except DouyinResolveError: + raise + except Exception as exc: + raise DouyinResolveError(f"抖音分享页解析失败:{exc}") from exc + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video: Optional[bool] = False, + skip_download: bool = False, + ) -> AudioDownloadResult: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + meta = self._resolve_meta(video_url) + if meta.content_type == "image": + return _build_result(meta, "", None) + + video_path = os.path.join(output_dir, f"{meta.aweme_id}.mp4") + audio_path = os.path.join(output_dir, f"{meta.aweme_id}.mp3") + + if skip_download: + return _build_result(meta, "", None) + + if not os.path.exists(video_path): + _download_file(meta.download_url, video_path) + + if not os.path.exists(audio_path): + try: + _extract_audio(video_path, audio_path) + except subprocess.CalledProcessError as exc: + raise RuntimeError("ffmpeg 转换 MP3 失败") from exc + + return _build_result( + meta, + audio_path, + video_path if need_video or os.path.exists(video_path) else None, + ) + + def download_video(self, video_url: str, output_dir: Union[str, None] = None) -> str: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + meta = self._resolve_meta(video_url) + if meta.content_type == "image": + raise DouyinResolveError("抖音图文内容没有可下载的视频文件") + + video_path = os.path.join(output_dir, f"{meta.aweme_id}.mp4") + if not os.path.exists(video_path): + _download_file(meta.download_url, video_path) + return video_path + + def download_subtitles( + self, + video_url: str, + output_dir: str = None, + langs: list = None, + ) -> Optional[TranscriptResult]: + meta = self._resolve_meta(video_url) + if meta.content_type != "image" or not meta.title: + return None + return TranscriptResult( + language="zh", + full_text=meta.title, + segments=[ + TranscriptSegment( + start=0, + end=meta.duration or 0, + text=meta.title, + ) + ], + ) diff --git a/backend/app/downloaders/douyin_helper/abogus.py b/backend/app/downloaders/douyin_helper/abogus.py new file mode 100644 index 0000000000000000000000000000000000000000..dab19951ef6103cb001edcbe1fc005649baea262 --- /dev/null +++ b/backend/app/downloaders/douyin_helper/abogus.py @@ -0,0 +1,635 @@ +""" +Original Author: +This file is from https://github.com/JoeanAmier/TikTokDownloader +And is licensed under the GNU General Public License v3.0 +If you use this code, please keep this license and the original author information. + +Modified by: +And this file is now a part of the https://github.com/Evil0ctal/Douyin_TikTok_Download_API open-source project. +This project is licensed under the Apache License 2.0, and the original author information is kept. + +Purpose: +This file is used to generate the `a_bogus` parameter for the Douyin Web API. + +Changes Made: +1. Changed the ua_code to compatible with the current config file User-Agent string in https://github.com/Evil0ctal/Douyin_TikTok_Download_API/blob/main/crawlers/douyin/web/config.yaml +""" + +from random import choice +from random import randint +from random import random +from re import compile +from time import time +from urllib.parse import urlencode +from urllib.parse import quote +from gmssl import sm3, func + +__all__ = ["ABogus", ] + + +class ABogus: + __filter = compile(r'%([0-9A-F]{2})') + __arguments = [0, 1, 14] + __ua_key = "\u0000\u0001\u000e" + __end_string = "cus" + __version = [1, 0, 1, 5] + __browser = "1536|742|1536|864|0|0|0|0|1536|864|1536|864|1536|742|24|24|MacIntel" + __reg = [ + 1937774191, + 1226093241, + 388252375, + 3666478592, + 2842636476, + 372324522, + 3817729613, + 2969243214, + ] + __str = { + "s0": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=", + "s1": "Dkdpgh4ZKsQB80/Mfvw36XI1R25+WUAlEi7NLboqYTOPuzmFjJnryx9HVGcaStCe=", + "s2": "Dkdpgh4ZKsQB80/Mfvw36XI1R25-WUAlEi7NLboqYTOPuzmFjJnryx9HVGcaStCe=", + "s3": "ckdp1h4ZKsUB80/Mfvw36XIgR25+WQAlEi7NLboqYTOPuzmFjJnryx9HVGDaStCe", + "s4": "Dkdpgh2ZmsQB80/MfvV36XI1R45-WUAlEixNLwoqYTOPuzKFjJnry79HbGcaStCe", + } + + def __init__(self, + # user_agent: str = USERAGENT, + platform: str = None, ): + self.chunk = [] + self.size = 0 + self.reg = self.__reg[:] + # self.ua_code = self.generate_ua_code(user_agent) + # Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36 + self.ua_code = [ + 76, + 98, + 15, + 131, + 97, + 245, + 224, + 133, + 122, + 199, + 241, + 166, + 79, + 34, + 90, + 191, + 128, + 126, + 122, + 98, + 66, + 11, + 14, + 40, + 49, + 110, + 110, + 173, + 67, + 96, + 138, + 252] + self.browser = self.generate_browser_info( + platform) if platform else self.__browser + self.browser_len = len(self.browser) + self.browser_code = self.char_code_at(self.browser) + + @classmethod + def list_1(cls, random_num=None, a=170, b=85, c=45, ) -> list: + return cls.random_list( + random_num, + a, + b, + 1, + 2, + 5, + c & a, + ) + + @classmethod + def list_2(cls, random_num=None, a=170, b=85, ) -> list: + return cls.random_list( + random_num, + a, + b, + 1, + 0, + 0, + 0, + ) + + @classmethod + def list_3(cls, random_num=None, a=170, b=85, ) -> list: + return cls.random_list( + random_num, + a, + b, + 1, + 0, + 5, + 0, + ) + + @staticmethod + def random_list( + a: float = None, + b=170, + c=85, + d=0, + e=0, + f=0, + g=0, + ) -> list: + r = a or (random() * 10000) + v = [ + r, + int(r) & 255, + int(r) >> 8, + ] + s = v[1] & b | d + v.append(s) + s = v[1] & c | e + v.append(s) + s = v[2] & b | f + v.append(s) + s = v[2] & c | g + v.append(s) + return v[-4:] + + @staticmethod + def from_char_code(*args): + return "".join(chr(code) for code in args) + + @classmethod + def generate_string_1( + cls, + random_num_1=None, + random_num_2=None, + random_num_3=None, + ): + return cls.from_char_code(*cls.list_1(random_num_1)) + cls.from_char_code( + *cls.list_2(random_num_2)) + cls.from_char_code(*cls.list_3(random_num_3)) + + def generate_string_2( + self, + url_params: str, + method="GET", + start_time=0, + end_time=0, + ) -> str: + a = self.generate_string_2_list( + url_params, + method, + start_time, + end_time, + ) + e = self.end_check_num(a) + a.extend(self.browser_code) + a.append(e) + return self.rc4_encrypt(self.from_char_code(*a), "y") + + def generate_string_2_list( + self, + url_params: str, + method="GET", + start_time=0, + end_time=0, + ) -> list: + start_time = start_time or int(time() * 1000) + end_time = end_time or (start_time + randint(4, 8)) + params_array = self.generate_params_code(url_params) + method_array = self.generate_method_code(method) + return self.list_4( + (end_time >> 24) & 255, + params_array[21], + self.ua_code[23], + (end_time >> 16) & 255, + params_array[22], + self.ua_code[24], + (end_time >> 8) & 255, + (end_time >> 0) & 255, + (start_time >> 24) & 255, + (start_time >> 16) & 255, + (start_time >> 8) & 255, + (start_time >> 0) & 255, + method_array[21], + method_array[22], + int(end_time / 256 / 256 / 256 / 256) >> 0, + int(start_time / 256 / 256 / 256 / 256) >> 0, + self.browser_len, + ) + + @staticmethod + def reg_to_array(a): + o = [0] * 32 + for i in range(8): + c = a[i] + o[4 * i + 3] = (255 & c) + c >>= 8 + o[4 * i + 2] = (255 & c) + c >>= 8 + o[4 * i + 1] = (255 & c) + c >>= 8 + o[4 * i] = (255 & c) + + return o + + def compress(self, a): + f = self.generate_f(a) + i = self.reg[:] + for o in range(64): + c = self.de(i[0], 12) + i[4] + self.de(self.pe(o), o) + c = (c & 0xFFFFFFFF) + c = self.de(c, 7) + s = (c ^ self.de(i[0], 12)) & 0xFFFFFFFF + + u = self.he(o, i[0], i[1], i[2]) + u = (u + i[3] + s + f[o + 68]) & 0xFFFFFFFF + + b = self.ve(o, i[4], i[5], i[6]) + b = (b + i[7] + c + f[o]) & 0xFFFFFFFF + + i[3] = i[2] + i[2] = self.de(i[1], 9) + i[1] = i[0] + i[0] = u + + i[7] = i[6] + i[6] = self.de(i[5], 19) + i[5] = i[4] + i[4] = (b ^ self.de(b, 9) ^ self.de(b, 17)) & 0xFFFFFFFF + + for l in range(8): + self.reg[l] = (self.reg[l] ^ i[l]) & 0xFFFFFFFF + + @classmethod + def generate_f(cls, e): + r = [0] * 132 + + for t in range(16): + r[t] = (e[4 * t] << 24) | (e[4 * t + 1] << + 16) | (e[4 * t + 2] << 8) | e[4 * t + 3] + r[t] &= 0xFFFFFFFF + + for n in range(16, 68): + a = r[n - 16] ^ r[n - 9] ^ cls.de(r[n - 3], 15) + a = a ^ cls.de(a, 15) ^ cls.de(a, 23) + r[n] = (a ^ cls.de(r[n - 13], 7) ^ r[n - 6]) & 0xFFFFFFFF + + for n in range(68, 132): + r[n] = (r[n - 68] ^ r[n - 64]) & 0xFFFFFFFF + + return r + + @staticmethod + def pad_array(arr, length=60): + while len(arr) < length: + arr.append(0) + return arr + + def fill(self, length=60): + size = 8 * self.size + self.chunk.append(128) + self.chunk = self.pad_array(self.chunk, length) + for i in range(4): + self.chunk.append((size >> 8 * (3 - i)) & 255) + + @staticmethod + def list_4( + a: int, + b: int, + c: int, + d: int, + e: int, + f: int, + g: int, + h: int, + i: int, + j: int, + k: int, + m: int, + n: int, + o: int, + p: int, + q: int, + r: int, + ) -> list: + return [ + 44, + a, + 0, + 0, + 0, + 0, + 24, + b, + n, + 0, + c, + d, + 0, + 0, + 0, + 1, + 0, + 239, + e, + o, + f, + g, + 0, + 0, + 0, + 0, + h, + 0, + 0, + 14, + i, + j, + 0, + k, + m, + 3, + p, + 1, + q, + 1, + r, + 0, + 0, + 0] + + @staticmethod + def end_check_num(a: list): + r = 0 + for i in a: + r ^= i + return r + + @classmethod + def decode_string(cls, url_string, ): + decoded = cls.__filter.sub(cls.replace_func, url_string) + return decoded + + @staticmethod + def replace_func(match): + return chr(int(match.group(1), 16)) + + @staticmethod + def de(e, r): + r %= 32 + return ((e << r) & 0xFFFFFFFF) | (e >> (32 - r)) + + @staticmethod + def pe(e): + return 2043430169 if 0 <= e < 16 else 2055708042 + + @staticmethod + def he(e, r, t, n): + if 0 <= e < 16: + return (r ^ t ^ n) & 0xFFFFFFFF + elif 16 <= e < 64: + return (r & t | r & n | t & n) & 0xFFFFFFFF + raise ValueError + + @staticmethod + def ve(e, r, t, n): + if 0 <= e < 16: + return (r ^ t ^ n) & 0xFFFFFFFF + elif 16 <= e < 64: + return (r & t | ~r & n) & 0xFFFFFFFF + raise ValueError + + @staticmethod + def convert_to_char_code(a): + d = [] + for i in a: + d.append(ord(i)) + return d + + @staticmethod + def split_array(arr, chunk_size=64): + result = [] + for i in range(0, len(arr), chunk_size): + result.append(arr[i:i + chunk_size]) + return result + + @staticmethod + def char_code_at(s): + return [ord(char) for char in s] + + def write(self, e, ): + self.size = len(e) + if isinstance(e, str): + e = self.decode_string(e) + e = self.char_code_at(e) + if len(e) <= 64: + self.chunk = e + else: + chunks = self.split_array(e, 64) + for i in chunks[:-1]: + self.compress(i) + self.chunk = chunks[-1] + + def reset(self, ): + self.chunk = [] + self.size = 0 + self.reg = self.__reg[:] + + def sum(self, e, length=60): + self.reset() + self.write(e) + self.fill(length) + self.compress(self.chunk) + return self.reg_to_array(self.reg) + + @classmethod + def generate_result_unit(cls, n, s): + r = "" + for i, j in zip(range(18, -1, -6), (16515072, 258048, 4032, 63)): + r += cls.__str[s][(n & j) >> i] + return r + + @classmethod + def generate_result_end(cls, s, e="s4"): + r = "" + b = ord(s[120]) << 16 + r += cls.__str[e][(b & 16515072) >> 18] + r += cls.__str[e][(b & 258048) >> 12] + r += "==" + return r + + @classmethod + def generate_result(cls, s, e="s4"): + # r = "" + # for i in range(len(s)//4): + # b = ((ord(s[i * 3]) << 16) | (ord(s[i * 3 + 1])) + # << 8) | ord(s[i * 3 + 2]) + # r += cls.generate_result_unit(b, e) + # return r + + r = [] + + for i in range(0, len(s), 3): + if i + 2 < len(s): + n = ( + (ord(s[i]) << 16) + | (ord(s[i + 1]) << 8) + | ord(s[i + 2]) + ) + elif i + 1 < len(s): + n = (ord(s[i]) << 16) | ( + ord(s[i + 1]) << 8 + ) + else: + n = ord(s[i]) << 16 + + for j, k in zip(range(18, -1, -6), + (0xFC0000, 0x03F000, 0x0FC0, 0x3F)): + if j == 6 and i + 1 >= len(s): + break + if j == 0 and i + 2 >= len(s): + break + r.append(cls.__str[e][(n & k) >> j]) + + r.append("=" * ((4 - len(r) % 4) % 4)) + return "".join(r) + + @classmethod + def generate_args_code(cls): + a = [] + for j in range(24, -1, -8): + a.append(cls.__arguments[0] >> j) + a.append(cls.__arguments[1] / 256) + a.append(cls.__arguments[1] % 256) + a.append(cls.__arguments[1] >> 24) + a.append(cls.__arguments[1] >> 16) + for j in range(24, -1, -8): + a.append(cls.__arguments[2] >> j) + return [int(i) & 255 for i in a] + + def generate_method_code(self, method: str = "GET") -> list[int]: + return self.sm3_to_array(self.sm3_to_array(method + self.__end_string)) + # return self.sum(self.sum(method + self.__end_string)) + + def generate_params_code(self, params: str) -> list[int]: + return self.sm3_to_array(self.sm3_to_array(params + self.__end_string)) + # return self.sum(self.sum(params + self.__end_string)) + + @classmethod + def sm3_to_array(cls, data: str | list) -> list[int]: + """ + 代码参考: https://github.com/Johnserf-Seed/f2/blob/main/f2/utils/abogus.py + + 计算请求体的 SM3 哈希值,并将结果转换为整数数组 + Calculate the SM3 hash value of the request body and convert the result to an array of integers + + Args: + data (Union[str, List[int]]): 输入数据 (Input data). + + Returns: + List[int]: 哈希值的整数数组 (Array of integers representing the hash value). + """ + + if isinstance(data, str): + b = data.encode("utf-8") + else: + b = bytes(data) # 将 List[int] 转换为字节数组 + + # 将字节数组转换为适合 sm3.sm3_hash 函数处理的列表格式 + h = sm3.sm3_hash(func.bytes_to_list(b)) + + # 将十六进制字符串结果转换为十进制整数列表 + return [int(h[i: i + 2], 16) for i in range(0, len(h), 2)] + + @classmethod + def generate_browser_info(cls, platform: str = "Win32") -> str: + inner_width = randint(1280, 1920) + inner_height = randint(720, 1080) + outer_width = randint(inner_width, 1920) + outer_height = randint(inner_height, 1080) + screen_x = 0 + screen_y = choice((0, 30)) + value_list = [ + inner_width, + inner_height, + outer_width, + outer_height, + screen_x, + screen_y, + 0, + 0, + outer_width, + outer_height, + outer_width, + outer_height, + inner_width, + inner_height, + 24, + 24, + platform, + ] + return "|".join(str(i) for i in value_list) + + @staticmethod + def rc4_encrypt(plaintext, key): + s = list(range(256)) + j = 0 + + for i in range(256): + j = (j + s[i] + ord(key[i % len(key)])) % 256 + s[i], s[j] = s[j], s[i] + + i = 0 + j = 0 + cipher = [] + + for k in range(len(plaintext)): + i = (i + 1) % 256 + j = (j + s[i]) % 256 + s[i], s[j] = s[j], s[i] + t = (s[i] + s[j]) % 256 + cipher.append(chr(s[t] ^ ord(plaintext[k]))) + + return ''.join(cipher) + + def get_value(self, + url_params: dict | str, + method="GET", + start_time=0, + end_time=0, + random_num_1=None, + random_num_2=None, + random_num_3=None, + ) -> str: + string_1 = self.generate_string_1( + random_num_1, + random_num_2, + random_num_3, + ) + string_2 = self.generate_string_2(urlencode(url_params) if isinstance( + url_params, dict) else url_params, method, start_time, end_time, ) + string = string_1 + string_2 + # return self.generate_result( + # string, "s4") + self.generate_result_end(string, "s4") + return self.generate_result(string, "s4") + + +if __name__ == "__main__": + bogus = ABogus() + USERAGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.212 Safari/537.36" + url_str = "https://www.douyin.com/aweme/v1/web/aweme/detail/?device_platform=webapp&aid=6383&channel=channel_pc_web&pc_client_type=1&version_code=190500&version_name=19.5.0&cookie_enabled=true&browser_language=zh-CN&browser_platform=Win32&browser_name=Firefox&browser_online=true&engine_name=Gecko&os_name=Windows&os_version=10&platform=PC&screen_width=1920&screen_height=1080&browser_version=124.0&engine_version=122.0.0.0&cpu_core_num=12&device_memory=8&aweme_id=7345492945006595379" + # 将url参数转换为字典 + url_params = dict([param.split("=") + for param in url_str.split("?")[1].split("&")]) + print(f"URL参数: {url_params}") + a_bogus = bogus.get_value(url_params, ) + # 使用url编码a_bogus + a_bogus = quote(a_bogus, safe='') + print(a_bogus) + print(USERAGENT) diff --git a/backend/app/downloaders/generic_downloader.py b/backend/app/downloaders/generic_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a0c021179cd2ddedde0302d816f50eb0245fd93d --- /dev/null +++ b/backend/app/downloaders/generic_downloader.py @@ -0,0 +1,128 @@ +"""通用 yt-dlp 下载器:用于用户在「下载配置」里登记的自定义平台。 + +不做任何站点特定逻辑——完全依赖 yt-dlp 内置 extractor。只把: + - 该平台的 Cookie/cookies-from-browser 注入 ydl_opts + - 全局代理注入 ydl_opts +""" +import logging +import os +import tempfile +from abc import ABC +from typing import Optional, Union + +import yt_dlp + +from app.downloaders.base import Downloader, DownloadQuality +from app.models.notes_model import AudioDownloadResult +from app.services.cookie_manager import CookieConfigManager +from app.services.proxy_config_manager import ProxyConfigManager +from app.utils.path_helper import get_data_dir + +logger = logging.getLogger(__name__) + + +class GenericYtdlpDownloader(Downloader, ABC): + """对任意 yt-dlp 支持站点的薄封装。按平台 key 读取 cookie 配置。""" + + def __init__(self, platform: str, cookie_domain: Optional[str] = None): + super().__init__() + self.platform = platform + # cookie 文件里 Netscape 格式需要 domain;不知道就用通用 . 让 yt-dlp 自己挑 + self.cookie_domain = cookie_domain or f".{platform}.com" + mgr = CookieConfigManager() + self._cookie = mgr.get(platform) + self._browser = mgr.get_browser(platform) + self._cookiefile = None if self._browser else self._write_netscape_cookie_file() + + def _write_netscape_cookie_file(self) -> Optional[str]: + if not self._cookie: + return None + lines = ["# Netscape HTTP Cookie File\n"] + for pair in self._cookie.split("; "): + if "=" in pair: + k, v = pair.split("=", 1) + lines.append(f"{self.cookie_domain}\tTRUE\t/\tFALSE\t0\t{k}\t{v}\n") + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') + tmp.writelines(lines) + tmp.close() + logger.info("已生成 [%s] Netscape Cookie 文件: %s", self.platform, tmp.name) + return tmp.name + + def _apply_ydl_extras(self, ydl_opts: dict) -> None: + proxy = ProxyConfigManager().get_proxy_url() + if proxy: + ydl_opts['proxy'] = proxy + if self._browser: + ydl_opts['cookiesfrombrowser'] = (self._browser,) + elif self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video: Optional[bool] = False, + skip_download: bool = False, + ) -> AudioDownloadResult: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + ydl_opts = { + 'format': 'bestaudio/best', + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + } + if skip_download: + ydl_opts['skip_download'] = True + self._apply_ydl_extras(ydl_opts) + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=not skip_download) + video_id = info.get("id") or "unknown" + title = info.get("title") or self.platform + duration = info.get("duration", 0) + cover_url = info.get("thumbnail") + ext = info.get("ext", "mp3") + audio_path = os.path.join(output_dir, f"{video_id}.{ext}") + + return AudioDownloadResult( + file_path=audio_path, + title=title, + duration=duration, + cover_url=cover_url, + platform=self.platform, + video_id=video_id, + raw_info={'tags': info.get('tags')}, + video_path=None, + ) + + def download_video( + self, + video_url: str, + output_dir: Union[str, None] = None, + ) -> str: + if output_dir is None: + output_dir = get_data_dir() + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + ydl_opts = { + 'format': 'bestvideo+bestaudio/best', + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + 'merge_output_format': 'mp4', + } + self._apply_ydl_extras(ydl_opts) + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + video_id = info.get("id") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + if not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件未找到: {video_path}") + return video_path diff --git a/backend/app/downloaders/kuaishou_downloader.py b/backend/app/downloaders/kuaishou_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..8b1d2a6e240825ae00599b8ea2b449888944e9aa --- /dev/null +++ b/backend/app/downloaders/kuaishou_downloader.py @@ -0,0 +1,97 @@ +import os +import subprocess +from abc import ABC +from typing import Union, Optional + +import requests + +from app.downloaders.base import Downloader +from app.downloaders.kuaishou_helper.kuaishou import KuaiShou +from app.enmus.note_enums import DownloadQuality +from app.models.audio_model import AudioDownloadResult +from app.utils.path_helper import get_data_dir + + +class KuaiShouDownloader(Downloader, ABC): + def __init__(self): + super().__init__() + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: str = "fast", + need_video: Optional[bool] = False + ) -> AudioDownloadResult: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + ks = KuaiShou() + video_raw_info = ks.run(video_url) + print(video_raw_info) + photo_info = video_raw_info['visionVideoDetail']['photo'] + video_id = photo_info['id'] + title = photo_info['caption'].strip().replace('\n', '').replace(' ', '_')[:50] + mp4_path = os.path.join(output_dir, f"{video_id}.mp4") + mp3_path = os.path.join(output_dir, f"{video_id}.mp3") + + if os.path.exists(mp3_path): + print(f"[已存在] 跳过下载: {mp3_path}") + return AudioDownloadResult( + file_path=mp3_path, + title=title, + duration=photo_info['duration'], + cover_url=photo_info['coverUrl'], + platform="kuaishou", + video_id=video_id, + raw_info={ + 'tags': ','.join(tag['name'] for tag in video_raw_info.get('tags', []) if tag.get('name')) + }, + video_path=mp4_path + ) + + # 下载 mp4 视频 + resp = requests.get(photo_info['photoUrl'], stream=True) + if resp.status_code == 200: + with open(mp4_path, "wb") as f: + for chunk in resp.iter_content(1024 * 1024): + f.write(chunk) + else: + raise Exception(f"视频下载失败: {resp.status_code}") + + # 使用 ffmpeg 转换为 mp3 + try: + subprocess.run([ + "ffmpeg", "-y", "-i", mp4_path, "-vn", "-acodec", "libmp3lame", mp3_path + ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError: + raise Exception("ffmpeg 转换 MP3 失败") + + return AudioDownloadResult( + file_path=mp3_path, + title=photo_info['caption'], + duration=photo_info['duration'], + cover_url=photo_info['coverUrl'], + platform="kuaishou", + video_id=video_id, + raw_info={ + 'tags': ','.join(tag['name'] for tag in video_raw_info.get('tags', []) if tag.get('name')) + }, + video_path=mp4_path + ) + + def download_video( + self, + video_url: str, + output_dir: Union[str, None] = None, + ) -> str: + print('self.download(video_url, output_dir).video_path',self.download(video_url, output_dir).video_path) + return self.download(video_url, output_dir).video_path + + +if __name__ == '__main__': + ks = KuaiShouDownloader() + ks.download('https://v.kuaishou.com/2vBqX74 王宝强携手刘昊然、岳云鹏上演精彩名场面 全程高能 看一遍笑一遍 "唐探1900 "快成长计划 ...更多') \ No newline at end of file diff --git a/backend/app/downloaders/kuaishou_helper/__init__.py b/backend/app/downloaders/kuaishou_helper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/downloaders/kuaishou_helper/kuaishou.py b/backend/app/downloaders/kuaishou_helper/kuaishou.py new file mode 100644 index 0000000000000000000000000000000000000000..5074e9c85bbd6c4457030779dea9681a6ecfcfda --- /dev/null +++ b/backend/app/downloaders/kuaishou_helper/kuaishou.py @@ -0,0 +1,101 @@ +import logging +import os +import re + +import requests +from dotenv import load_dotenv + +from app.services.cookie_manager import CookieConfigManager +from app.utils.logger import get_logger +KUAISHOU_API_BASE = 'https://www.kuaishou.com/graphql' +KUAISHOU_URL = "https://www.kuaishou.com/" +load_dotenv() +headers = { + 'Accept-Language': 'zh-CN,zh;q=0.9', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + # 'Cookie': 'did=web_9e8cfa4403000587b9e7d67233e6b04c; didv=1719811812378; kpf=PC_WEB; clientid=3; kpn=KUAISHOU_VISION', + 'Origin': 'https://www.kuaishou.com', + 'Pragma': 'no-cache', + 'Referer': 'https://www.kuaishou.com/', + 'Sec-Fetch-Dest': 'empty', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Site': 'same-origin', + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36', + 'accept': '*/*', + 'content-type': 'application/json', + 'sec-ch-ua': '"Not/A)Brand";v="8", "Chromium";v="126", "Google Chrome";v="126"', + 'sec-ch-ua-mobile': '?0', + 'sec-ch-ua-platform': '"Windows"', + # 'Cookie':cookies.strip() +} + +logger = get_logger(__name__) + +cfm=CookieConfigManager() +class KuaiShou: + def __init__(self): + self.header = headers.copy() + self.cookie = None + + @staticmethod + def _extract_kuaishou_link(text): + + url = re.findall('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', text) + return url[0] + + def get_photo_id(self, url): + response = requests.get(url, allow_redirects=True, headers=self.header) + real_url = response.url + # 提取short—video/后面的id + pattern = re.compile(r'short-video/(\w+)') + match = pattern.search(real_url) + return match.group().split('/')[1] + + def get_temp_cookies(self): + is_exist = cfm.get('kuaishou') + print(is_exist) + if is_exist: + return is_exist + res = requests.get(url=KUAISHOU_URL, headers=self.header, allow_redirects=True) + cookie_string = '; '.join([f"{k}={v}" for k, v in res.cookies.get_dict().items()]) + return cookie_string + + def get_video_details(self, url, photo_id): + json_data = { + 'operationName': 'visionVideoDetail', + "variables": {"photoId": photo_id, "page": "detail"}, + "query": "query visionVideoDetail($photoId: String, $type: String, $page: String, $webPageArea: String) {\n visionVideoDetail(photoId: $photoId, type: $type, page: $page, webPageArea: $webPageArea) {\n status\n type\n author {\n id\n name\n following\n headerUrl\n __typename\n }\n photo {\n id\n duration\n caption\n likeCount\n realLikeCount\n coverUrl\n photoUrl\n liked\n timestamp\n expTag\n llsid\n viewCount\n videoRatio\n stereoType\n croppedPhotoUrl\n manifest {\n mediaType\n businessType\n version\n adaptationSet {\n id\n duration\n representation {\n id\n defaultSelect\n backupUrl\n codecs\n url\n height\n width\n avgBitrate\n maxBitrate\n m3u8Slice\n qualityType\n qualityLabel\n frameRate\n featureP2sp\n hidden\n disableAdaptive\n __typename\n }\n __typename\n }\n __typename\n }\n __typename\n }\n tags {\n type\n name\n __typename\n }\n commentLimit {\n canAddComment\n __typename\n }\n llsid\n danmakuSwitch\n __typename\n }\n}\n" + } + response = requests.post(url=KUAISHOU_API_BASE, headers=self.header, json=json_data) + if response.status_code == 200: + response.raise_for_status() + + return response.json() + else: + return None + + def run(self, url): + real_url = self._extract_kuaishou_link(url) + if not real_url: + logger.error(f"快手视频 URL 解析失败 {url}") + + cookies = self.get_temp_cookies() + if not cookies: + logger.error(f"快手视频 cookies 解析失败 {url},请考虑设置环境变量 KUAISHOU_COOKIES") + + self.header['Cookie'] = cookies.strip() + photo_id = self.get_photo_id(real_url) + if photo_id is None: + logger.error(f"快手视频 ID 解析失败 {url}") + video_details = self.get_video_details(real_url, photo_id) + print(video_details) + if video_details is None: + logger.error(f"快手视频详情解析失败 {url}") + return video_details['data'] + + +if __name__ == '__main__': + ks = KuaiShou() + ks.run( + 'https://v.kuaishou.com/2vBqX74 王宝强携手刘昊然、岳云鹏上演精彩名场面 全程高能 看一遍笑一遍 "唐探1900 "快成长计划 ...更多') diff --git a/backend/app/downloaders/local_downloader.py b/backend/app/downloaders/local_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..2808fb6cb0ec7bc2c5d1e74495a3915a5b15ec74 --- /dev/null +++ b/backend/app/downloaders/local_downloader.py @@ -0,0 +1,137 @@ +import os +import subprocess +from abc import ABC +from typing import Optional + +from app.downloaders.base import Downloader +from app.enmus.note_enums import DownloadQuality +from app.models.audio_model import AudioDownloadResult +import os +import subprocess + +from app.utils.video_helper import save_cover_to_static + + +class LocalDownloader(Downloader, ABC): + def __init__(self): + + super().__init__() + + + def extract_cover(self, input_path: str, output_dir: Optional[str] = None) -> str: + """ + 从本地视频文件中提取一张封面图(默认取第一帧) + :param input_path: 输入视频路径 + :param output_dir: 输出目录,默认和视频同目录 + :return: 提取出的封面图片路径 + """ + if not os.path.exists(input_path): + raise FileNotFoundError(f"输入文件不存在: {input_path}") + + if output_dir is None: + output_dir = os.path.dirname(input_path) + + base_name = os.path.splitext(os.path.basename(input_path))[0] + output_path = os.path.join(output_dir, f"{base_name}_cover.jpg") + + try: + command = [ + 'ffmpeg', + '-i', input_path, + '-ss', '00:00:01', # 跳到视频第1秒,防止黑屏 + '-vframes', '1', # 只截取一帧 + '-q:v', '2', # 输出质量高一点(qscale,2是很高) + '-y', # 覆盖 + output_path + ] + subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + + if not os.path.exists(output_path): + raise RuntimeError(f"封面图片生成失败: {output_path}") + + return output_path + except subprocess.CalledProcessError as e: + raise RuntimeError(f"提取封面失败: {output_path}") from e + + def convert_to_mp3(self,input_path: str, output_path: str = None) -> str: + """ + 将本地视频文件转为 MP3 音频文件 + :param input_path: 输入文件路径(如 .mp4) + :param output_path: 输出文件路径(可选,默认同目录同名 .mp3) + :return: 生成的 mp3 文件路径 + """ + if not os.path.exists(input_path): + raise FileNotFoundError(f"输入文件不存在: {input_path}") + + if output_path is None: + base, _ = os.path.splitext(input_path) + output_path = base + ".mp3" + try: + # 调用 ffmpeg 转换 + command = [ + 'ffmpeg', + '-i', input_path, + '-vn', # 不要视频流 + '-acodec', 'libmp3lame', # 使用mp3编码 + '-y', # 覆盖输出文件 + output_path + ] + + subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True) + + if not os.path.exists(output_path): + raise RuntimeError(f"mp3 文件生成失败: {output_path}") + + return output_path + except subprocess.CalledProcessError as e: + raise RuntimeError(f"mp3 文件生成失败: {output_path}") from e + def download_video(self, video_url: str, output_dir: str = None) -> str: + """ + 处理本地文件路径,返回视频文件路径 + """ + if video_url.startswith('/uploads'): + project_root = os.getcwd() + video_url = os.path.join(project_root, video_url.lstrip('/')) + video_url = os.path.normpath(video_url) + + if not os.path.exists(video_url): + raise FileNotFoundError() + return video_url + def download( + self, + video_url: str, + output_dir: str = None, + quality: DownloadQuality = "fast", + need_video: Optional[bool] = False + ) -> AudioDownloadResult: + """ + 处理本地文件路径,返回音频元信息 + """ + if video_url.startswith('/uploads'): + project_root = os.getcwd() + video_url = os.path.join(project_root, video_url.lstrip('/')) + video_url = os.path.normpath(video_url) + + if not os.path.exists(video_url): + raise FileNotFoundError(f"本地文件不存在: {video_url}") + + file_name = os.path.basename(video_url) + title, _ = os.path.splitext(file_name) + print(title, file_name,video_url) + file_path=self.convert_to_mp3(video_url) + cover_path = self.extract_cover(video_url) + cover_url = save_cover_to_static(cover_path) + + print('file——path',file_path) + return AudioDownloadResult( + file_path=file_path, + title=title, + duration=0, # 可选:后续加上读取时长 + cover_url=cover_url, # 暂无封面 + platform="local", + video_id=title, + raw_info={ + 'path': file_path + }, + video_path=None + ) diff --git a/backend/app/downloaders/xiaohongshu_downloader.py b/backend/app/downloaders/xiaohongshu_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..75ff5f666461b3a6074e463f8c99a5dceecd52bd --- /dev/null +++ b/backend/app/downloaders/xiaohongshu_downloader.py @@ -0,0 +1,133 @@ +"""小红书下载器:基于 yt-dlp 内置 XiaoHongShu extractor。 + +URL 模式: + - https://www.xiaohongshu.com/explore/{id} + - https://www.xiaohongshu.com/discovery/item/{id} + - 短链 xhslink.com/xxx 由 yt-dlp 自行跟随重定向 + +小红书很多内容是图文笔记(无视频/音频)。无视频的会触发 yt-dlp 报「请求格式不可用」, +前端会展示生成失败——这是预期行为,不强行兜底。 +""" +import os +import logging +import tempfile +from abc import ABC +from typing import Union, Optional + +import yt_dlp + +from app.downloaders.base import Downloader, DownloadQuality +from app.models.notes_model import AudioDownloadResult +from app.services.cookie_manager import CookieConfigManager +from app.utils.path_helper import get_data_dir +from app.utils.url_parser import extract_video_id, clean_url + +logger = logging.getLogger(__name__) + + +class XiaohongshuDownloader(Downloader, ABC): + def __init__(self): + super().__init__() + self._cookie_mgr = CookieConfigManager() + self._cookie = self._cookie_mgr.get('xiaohongshu') + self._browser = self._cookie_mgr.get_browser('xiaohongshu') + self._cookiefile = None if self._browser else self._write_netscape_cookie_file() + + def _write_netscape_cookie_file(self) -> Optional[str]: + if not self._cookie: + logger.warning("小红书 Cookie 未配置,部分内容可能下载失败") + return None + lines = ["# Netscape HTTP Cookie File\n"] + for pair in self._cookie.split("; "): + if "=" in pair: + key, value = pair.split("=", 1) + lines.append(f".xiaohongshu.com\tTRUE\t/\tFALSE\t0\t{key}\t{value}\n") + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') + tmp.writelines(lines) + tmp.close() + logger.info("已生成小红书 Netscape Cookie 文件: %s (条目: %d)", tmp.name, len(lines) - 1) + return tmp.name + + def _apply_cookie(self, ydl_opts: dict) -> None: + if self._browser: + ydl_opts['cookiesfrombrowser'] = (self._browser,) + logger.info(f"小红书使用 cookies-from-browser: {self._browser}") + elif self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video: Optional[bool] = False, + skip_download: bool = False, + ) -> AudioDownloadResult: + # 从分享文案中提取干净链接(标题+不可见字符+短链 整段粘贴也能用) + video_url = clean_url(video_url) + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + ydl_opts = { + 'format': 'bestaudio/best', + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + } + if skip_download: + ydl_opts['skip_download'] = True + self._apply_cookie(ydl_opts) + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=not skip_download) + video_id = info.get("id") + title = info.get("title") + duration = info.get("duration", 0) + cover_url = info.get("thumbnail") + ext = info.get("ext", "mp3") + audio_path = os.path.join(output_dir, f"{video_id}.{ext}") + + return AudioDownloadResult( + file_path=audio_path, + title=title, + duration=duration, + cover_url=cover_url, + platform="xiaohongshu", + video_id=video_id, + raw_info={'tags': info.get('tags')}, + video_path=None, + ) + + def download_video( + self, + video_url: str, + output_dir: Union[str, None] = None, + ) -> str: + video_url = clean_url(video_url) + if output_dir is None: + output_dir = get_data_dir() + video_id = extract_video_id(video_url, "xiaohongshu") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + if os.path.exists(video_path): + return video_path + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + ydl_opts = { + 'format': 'bestvideo+bestaudio/best', + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + 'merge_output_format': 'mp4', + } + self._apply_cookie(ydl_opts) + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + video_id = info.get("id") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + if not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件未找到: {video_path}") + return video_path diff --git a/backend/app/downloaders/xiaoyuzhoufm_download.py b/backend/app/downloaders/xiaoyuzhoufm_download.py new file mode 100644 index 0000000000000000000000000000000000000000..e0cc85c32dad5c0b19043394c43069ca62f67eeb --- /dev/null +++ b/backend/app/downloaders/xiaoyuzhoufm_download.py @@ -0,0 +1,25 @@ +from typing import Union, Optional + +import requests + +from app.downloaders.base import Downloader +from app.enmus.note_enums import DownloadQuality +from app.models.audio_model import AudioDownloadResult + +url='https://www.xiaoyuzhoufm.com/_next/data/5Pvt_oGntgdyBD_XgwBaB/podcast/62382c1103bea1ebfffa1c00.json?id=62382c1103bea1ebfffa1c00' +header ={ + 'user-agent':'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36' +} + +response = requests.get(url, headers=header) +print(response.json()) + +class Xiaoyuzhoufm_download(Downloader): + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video:Optional[bool]=False + ) -> AudioDownloadResult: + pass \ No newline at end of file diff --git a/backend/app/downloaders/youtube_downloader.py b/backend/app/downloaders/youtube_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed185388021f65cddd33ec98deb9364c710db53 --- /dev/null +++ b/backend/app/downloaders/youtube_downloader.py @@ -0,0 +1,259 @@ +import os +import logging +import tempfile +from abc import ABC +from typing import Union, Optional, List + +import yt_dlp + +from app.downloaders.base import Downloader, DownloadQuality +from app.downloaders.youtube_subtitle import YouTubeSubtitleFetcher +from app.models.notes_model import AudioDownloadResult +from app.models.transcriber_model import TranscriptResult +from app.services.cookie_manager import CookieConfigManager +from app.services.proxy_config_manager import ProxyConfigManager +from app.utils.path_helper import get_data_dir +from app.utils.url_parser import extract_video_id + +logger = logging.getLogger(__name__) + + +def _apply_proxy(ydl_opts: dict) -> dict: + """YouTube 在国内需要代理。配置了全局代理就塞进 yt-dlp opts。""" + proxy = ProxyConfigManager().get_proxy_url() + if proxy: + ydl_opts['proxy'] = proxy + logger.info(f"yt-dlp 走代理: {proxy}") + return ydl_opts + + +def _apply_youtube_extractor_args(ydl_opts: dict) -> dict: + """YouTube player_client 选择。 + + 默认不再覆盖、交给 yt-dlp 的内置策略: + 早期为绕开 SSAP 实验(issue #12482)硬编码过 ['tv', 'web_safari'], + 但 YouTube 后来对 tv 客户端做「全量 DRM」实验(issue #12563),命中的会话 + 所有视频都报 "This video is DRM protected";而 web 系客户端需要 JS runtime + (deno)解 n challenge,装好后 yt-dlp 默认客户端列表即可正常取流。 + 硬编码的客户端列表会随 YouTube 风控变化反复失效,不如跟随 yt-dlp 升级。 + + 如需临时指定,可设环境变量 YT_PLAYER_CLIENT(逗号分隔),如 + YT_PLAYER_CLIENT=web_safari,android_vr。 + """ + clients = os.getenv('YT_PLAYER_CLIENT', '').strip() + if clients: + ydl_opts.setdefault('extractor_args', {}) + ydl_opts['extractor_args'].setdefault('youtube', {}) + ydl_opts['extractor_args']['youtube']['player_client'] = [ + c.strip() for c in clients.split(',') if c.strip() + ] + return ydl_opts + + +class YoutubeDownloader(Downloader, ABC): + def __init__(self): + + super().__init__() + self._cookie_mgr = CookieConfigManager() + self._cookie = self._cookie_mgr.get('youtube') + # 优先级:浏览器实时 cookies > 粘贴的 cookie 字符串。 + # 配了浏览器就走 yt-dlp `cookiesfrombrowser`,能避开 YouTube 的会话轮换风控。 + self._browser = self._cookie_mgr.get_browser('youtube') + self._cookiefile = None if self._browser else self._write_netscape_cookie_file() + + def _write_netscape_cookie_file(self) -> Optional[str]: + """将 YouTube Cookie 写入 Netscape 格式临时文件,供 yt-dlp cookiefile 使用。 + + 没有 Cookie 时返回 None;YouTube 现在没 Cookie 基本会被拦在「Sign in to confirm you're not a bot」。 + """ + if not self._cookie: + logger.warning("YouTube Cookie 未配置,下载可能会被风控为机器人") + return None + lines = ["# Netscape HTTP Cookie File\n"] + for pair in self._cookie.split("; "): + if "=" in pair: + key, value = pair.split("=", 1) + lines.append(f".youtube.com\tTRUE\t/\tFALSE\t0\t{key}\t{value}\n") + tmp = tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') + tmp.writelines(lines) + tmp.close() + logger.info("已生成 YouTube Netscape Cookie 文件: %s (条目: %d)", tmp.name, len(lines) - 1) + return tmp.name + + def download( + self, + video_url: str, + output_dir: Union[str, None] = None, + quality: DownloadQuality = "fast", + need_video: Optional[bool] = False, + skip_download: bool = False, + ) -> AudioDownloadResult: + if output_dir is None: + output_dir = get_data_dir() + if not output_dir: + output_dir = self.cache_data + os.makedirs(output_dir, exist_ok=True) + + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + + ydl_opts = { + 'format': 'bestaudio[ext=m4a]/bestaudio/best', + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + } + + if skip_download: + ydl_opts['skip_download'] = True + + _apply_proxy(ydl_opts) + _apply_youtube_extractor_args(ydl_opts) + if self._browser: + # (browser_name,) 形式即可;profile/keyring/container 留默认 + ydl_opts['cookiesfrombrowser'] = (self._browser,) + logger.info(f"YouTube 使用 cookies-from-browser: {self._browser}") + elif self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=not skip_download) + video_id = info.get("id") + title = info.get("title") + duration = info.get("duration", 0) + cover_url = info.get("thumbnail") + ext = info.get("ext", "m4a") + audio_path = os.path.join(output_dir, f"{video_id}.{ext}") + + return AudioDownloadResult( + file_path=audio_path, + title=title, + duration=duration, + cover_url=cover_url, + platform="youtube", + video_id=video_id, + raw_info={'tags': info.get('tags')}, + video_path=None, + ) + except Exception as exc: + # DRM / 反爬 / 格式不可用等情况下 yt-dlp 拉不动;只要本次仅需要 metadata + # (即字幕路径,skip_download=True),就退到 YouTube oEmbed 兜底拿标题+封面, + # 让流程能继续走总结。需要下载音视频时只能向上抛。 + if not skip_download: + raise + logger.warning(f"yt-dlp 获取元数据失败,回退 oEmbed: {exc}") + return self._fallback_metadata(video_url) + + def _fallback_metadata(self, video_url: str) -> AudioDownloadResult: + """yt-dlp 失败时的兜底:用 YouTube 公开的 oEmbed 接口拿基础 metadata。 + + 只能拿到 title / thumbnail / author 这几样;duration / tags 拿不到,做空值处理。 + DRM、bot 拦截等都不影响 oEmbed。 + """ + import requests + + video_id = extract_video_id(video_url, "youtube") or "" + title = video_id or "YouTube 视频" + cover = f"https://i.ytimg.com/vi/{video_id}/hqdefault.jpg" if video_id else "" + try: + proxies = None + proxy = ProxyConfigManager().get_proxy_url() + if proxy: + proxies = {"http": proxy, "https": proxy} + resp = requests.get( + "https://www.youtube.com/oembed", + params={"url": video_url, "format": "json"}, + proxies=proxies, + timeout=10, + ) + resp.raise_for_status() + data = resp.json() + if data.get("title"): + title = data["title"] + if data.get("thumbnail_url"): + cover = data["thumbnail_url"] + logger.info(f"oEmbed 兜底成功:title={title}") + except Exception as e: + logger.warning(f"oEmbed 兜底也失败,使用最小元数据:{e}") + + return AudioDownloadResult( + file_path="", # 没下载音视频文件 + title=title, + duration=0, # oEmbed 不返回时长 + cover_url=cover, + platform="youtube", + video_id=video_id, + raw_info={"tags": []}, # oEmbed 不返回标签 + video_path=None, + ) + + def download_video( + self, + video_url: str, + output_dir: Union[str, None] = None, + ) -> str: + """ + 下载视频,返回视频文件路径 + """ + if output_dir is None: + output_dir = get_data_dir() + video_id = extract_video_id(video_url, "youtube") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + if os.path.exists(video_path): + return video_path + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, "%(id)s.%(ext)s") + + ydl_opts = { + # 这里下载的视频只用于截图网格/视频理解抽帧,720p 足够: + # 不设上限的话 bestvideo 会选 4K AV1(动辄 300MB+,下载和 ffmpeg + # 解码抽帧都极慢)。优先 avc1(解码远快于 av01),同高度再退 av01。 + 'format': ( + 'bestvideo[height<=720][vcodec^=avc1]+bestaudio[ext=m4a]' + '/bestvideo[height<=720][ext=mp4]+bestaudio[ext=m4a]' + '/best[height<=720][ext=mp4]/best[ext=mp4]' + ), + 'outtmpl': output_path, + 'noplaylist': True, + 'quiet': False, + 'merge_output_format': 'mp4', # 确保合并成 mp4 + } + + _apply_proxy(ydl_opts) + _apply_youtube_extractor_args(ydl_opts) + if self._browser: + # (browser_name,) 形式即可;profile/keyring/container 留默认 + ydl_opts['cookiesfrombrowser'] = (self._browser,) + logger.info(f"YouTube 使用 cookies-from-browser: {self._browser}") + elif self._cookiefile: + ydl_opts['cookiefile'] = self._cookiefile + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(video_url, download=True) + video_id = info.get("id") + video_path = os.path.join(output_dir, f"{video_id}.mp4") + + if not os.path.exists(video_path): + raise FileNotFoundError(f"视频文件未找到: {video_path}") + + return video_path + + def download_subtitles(self, video_url: str, output_dir: str = None, + langs: List[str] = None) -> Optional[TranscriptResult]: + """ + 通过 YouTube InnerTube API 直接获取字幕(优先人工字幕,其次自动生成)。 + 比 yt_dlp 方式更轻量,无需写临时文件到磁盘。 + + :param video_url: 视频链接 + :param output_dir: 未使用(保留接口兼容) + :param langs: 优先语言列表 + :return: TranscriptResult 或 None + """ + if langs is None: + langs = ['zh-Hans', 'zh', 'zh-CN', 'zh-TW', 'en', 'en-US', 'ja'] + + video_id = extract_video_id(video_url, "youtube") + fetcher = YouTubeSubtitleFetcher() + print( + f"尝试获取字幕,video_id={video_id}, langs={langs}" + ) + return fetcher.fetch_subtitles(video_id, langs) diff --git a/backend/app/downloaders/youtube_subtitle.py b/backend/app/downloaders/youtube_subtitle.py new file mode 100644 index 0000000000000000000000000000000000000000..81a559a45c625fb9640a415a70d4211bcc7bbd16 --- /dev/null +++ b/backend/app/downloaders/youtube_subtitle.py @@ -0,0 +1,113 @@ +""" +通过 youtube-transcript-api 获取 YouTube 字幕。 +优先人工字幕,其次自动生成字幕。不依赖 yt_dlp,无需下载任何文件。 +""" + +from typing import Optional, List + +from youtube_transcript_api import YouTubeTranscriptApi + +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.services.proxy_config_manager import ProxyConfigManager +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +class YouTubeSubtitleFetcher: + """通过 youtube-transcript-api 获取 YouTube 字幕。""" + + def __init__(self): + # 配了全局代理就给 youtube-transcript-api 套一个带 proxies 的 requests.Session, + # 否则国内拉字幕同样会超时。代理未配置时退回默认无代理客户端。 + proxy = ProxyConfigManager().get_proxy_url() + if proxy: + try: + import requests + session = requests.Session() + session.proxies = {"http": proxy, "https": proxy} + self._api = YouTubeTranscriptApi(http_client=session) + logger.info(f"YouTube 字幕走代理: {proxy}") + except Exception as e: + logger.warning(f"为 youtube-transcript-api 注入代理失败,回退无代理: {e}") + self._api = YouTubeTranscriptApi() + else: + self._api = YouTubeTranscriptApi() + + def fetch_subtitles( + self, + video_id: str, + langs: Optional[List[str]] = None, + ) -> Optional[TranscriptResult]: + if langs is None: + langs = ["zh-Hans", "zh", "zh-CN", "zh-TW", "en", "en-US", "ja"] + + try: + # 1. 列出所有可用字幕 + transcript_list = self._api.list(video_id) + + available = [] + for t in transcript_list: + available.append( + f"{t.language_code}({'auto' if t.is_generated else 'manual'})" + ) + logger.info(f"可用字幕轨道: {', '.join(available)}") + + # 2. 按优先级查找:先人工字幕,再自动字幕 + transcript = None + try: + transcript = transcript_list.find_manually_created_transcript(langs) + logger.info(f"选中人工字幕: {transcript.language_code} ({transcript.language})") + except Exception: + try: + transcript = transcript_list.find_generated_transcript(langs) + logger.info(f"选中自动字幕: {transcript.language_code} ({transcript.language})") + except Exception: + # 都没匹配,取第一个可用的 + for t in transcript_list: + transcript = t + source = "auto" if t.is_generated else "manual" + logger.info(f"使用首个可用字幕: {t.language_code} ({source})") + break + + if not transcript: + logger.info(f"YouTube 视频 {video_id} 没有任何可用字幕") + return None + + # 3. 获取字幕内容 + fetched = transcript.fetch() + segments = [] + for snippet in fetched: + text = snippet.get("text", "").strip() if isinstance(snippet, dict) else str(snippet).strip() + if not text: + continue + start = snippet.get("start", 0) if isinstance(snippet, dict) else 0 + duration = snippet.get("duration", 0) if isinstance(snippet, dict) else 0 + segments.append(TranscriptSegment( + start=float(start), + end=float(start) + float(duration), + text=text, + )) + + if not segments: + logger.warning(f"YouTube 字幕内容为空: {video_id}") + return None + + full_text = " ".join(seg.text for seg in segments) + logger.info(f"成功获取 YouTube 字幕,共 {len(segments)} 段") + + return TranscriptResult( + language=transcript.language_code, + full_text=full_text, + segments=segments, + raw={ + "source": "youtube_transcript_api", + "language": transcript.language, + "language_code": transcript.language_code, + "is_generated": transcript.is_generated, + }, + ) + + except Exception as e: + logger.warning(f"YouTube 字幕获取失败: {e}") + return None diff --git a/backend/app/enmus/exception.py b/backend/app/enmus/exception.py new file mode 100644 index 0000000000000000000000000000000000000000..f4de841bf3f2fafd6df300ac564306f2112f3e07 --- /dev/null +++ b/backend/app/enmus/exception.py @@ -0,0 +1,21 @@ +import enum + + +class ProviderErrorEnum(enum.Enum): + CONNECTION_TEST_FAILED = (200101, "供应商连接测试失败") + SAVE_FAILED = (200102, "供应商保存失败") + CREATE_FAILED = (200103, "供应商创建失败") + NOT_FOUND = (200104, "供应商不存在/未保存") + WRONG_PARAMETER = (200105, "API / API 地址不正确") + UNKNOW_ERROR = (200106, "未知错误") + + def __init__(self, code, message): + self.code = code + self.message = message + +class NoteErrorEnum(enum.Enum): + PLATFORM_NOT_SUPPORTED = (300101 ,"选择的平台不受支持") + + def __init__(self, code, message): + self.code = code + self.message = message \ No newline at end of file diff --git a/backend/app/enmus/note_enums.py b/backend/app/enmus/note_enums.py new file mode 100644 index 0000000000000000000000000000000000000000..be5d3c1c0d053435f8c78d81688a90398594e9fd --- /dev/null +++ b/backend/app/enmus/note_enums.py @@ -0,0 +1,7 @@ +import enum + + +class DownloadQuality(str, enum.Enum): + fast = "fast" + medium = "medium" + slow = "slow" diff --git a/backend/app/enmus/task_status_enums.py b/backend/app/enmus/task_status_enums.py new file mode 100644 index 0000000000000000000000000000000000000000..f83daa11b1938779f934d4957f8508259af9f23d --- /dev/null +++ b/backend/app/enmus/task_status_enums.py @@ -0,0 +1,28 @@ +import enum + + +class TaskStatus(str, enum.Enum): + PENDING = "PENDING" + PARSING = "PARSING" + DOWNLOADING = "DOWNLOADING" + TRANSCRIBING = "TRANSCRIBING" + SUMMARIZING = "SUMMARIZING" + FORMATTING = "FORMATTING" + SAVING = "SAVING" + SUCCESS = "SUCCESS" + FAILED = "FAILED" + + @classmethod + def description(cls, status): + desc_map = { + cls.PENDING: "排队中", + cls.PARSING: "解析链接", + cls.DOWNLOADING: "下载中", + cls.TRANSCRIBING: "转录中", + cls.SUMMARIZING: "总结中", + cls.FORMATTING: "格式化中", + cls.SAVING: "保存中", + cls.SUCCESS: "完成", + cls.FAILED: "失败", + } + return desc_map.get(status, "未知状态") diff --git a/backend/app/exceptions/__init__.py b/backend/app/exceptions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/exceptions/biz_exception.py b/backend/app/exceptions/biz_exception.py new file mode 100644 index 0000000000000000000000000000000000000000..d7cc1b5f71b51351f8cd3b7f29bf0795b81316f8 --- /dev/null +++ b/backend/app/exceptions/biz_exception.py @@ -0,0 +1,6 @@ +# exceptions/biz_exception.py + +class BizException(Exception): + def __init__(self, code: int, message: str = "业务异常"): + self.code = code + self.message = message \ No newline at end of file diff --git a/backend/app/exceptions/exception_handlers.py b/backend/app/exceptions/exception_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd10532c1cfebb88e69f01aa5791f502ca40a73 --- /dev/null +++ b/backend/app/exceptions/exception_handlers.py @@ -0,0 +1,33 @@ +# middlewares/exception_handler.py + +from fastapi import Request +from fastapi import FastAPI + +from app.enmus.exception import NoteErrorEnum +from app.exceptions.biz_exception import BizException +from app.exceptions.note import NoteError +from app.exceptions.provider import ProviderError +from app.utils.logger import get_logger +from app.utils.response import ResponseWrapper as R +import traceback + +logger = get_logger(__name__) + +def register_exception_handlers(app: FastAPI): + @app.exception_handler(BizException) + async def biz_exception_handler(request: Request, exc: BizException): + logger.error(f"BizException: {exc.code} - {exc.message}") + return R.error(code=exc.code, msg=str(exc.message)) + @app.exception_handler(NoteError) + async def note_exception_handler(request: Request, exc: NoteError): + logger.error(f"NoteError: {exc.code} - {exc.message}") + return R.error(code=exc.code, msg=str(exc.message)) + @app.exception_handler(ProviderError) + async def provider_exception_handler(request: Request, exc: ProviderError): + logger.error(f"供应商模块错误: {exc.code} - {exc.message}") + return R.error(code=exc.code, msg=str(exc.message)) + + @app.exception_handler(Exception) + async def general_exception_handler(request: Request, exc: Exception): + logger.error(f"系统异常: {str(exc)}\n{traceback.format_exc()}") + return R.error(code=500000, msg="系统异常") \ No newline at end of file diff --git a/backend/app/exceptions/note.py b/backend/app/exceptions/note.py new file mode 100644 index 0000000000000000000000000000000000000000..ce29ab72eba8ffadeb4875adc0560e16e74ef84b --- /dev/null +++ b/backend/app/exceptions/note.py @@ -0,0 +1,9 @@ +# exceptions.py +from app.enmus.exception import ProviderErrorEnum + + +class NoteError(Exception): + def __init__(self, message: str,code: ProviderErrorEnum) -> None: + super().__init__(message) + self.code=code + self.message = message \ No newline at end of file diff --git a/backend/app/exceptions/provider.py b/backend/app/exceptions/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..53a511ea3ed64152650ff36ceaf54246d286cb8d --- /dev/null +++ b/backend/app/exceptions/provider.py @@ -0,0 +1,12 @@ +# exceptions.py +from app.enmus.exception import ProviderErrorEnum + + +class ProviderError(Exception): + def __init__(self, message: str,code: ProviderErrorEnum) -> None: + super().__init__(message) + self.code=code + self.message = message + + + diff --git a/backend/app/gpt/__init__.py b/backend/app/gpt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/gpt/base.py b/backend/app/gpt/base.py new file mode 100644 index 0000000000000000000000000000000000000000..058ea6c40325acdcc2fb5980a16fb0e077936333 --- /dev/null +++ b/backend/app/gpt/base.py @@ -0,0 +1,17 @@ +from abc import ABC,abstractmethod + +from app.models.gpt_model import GPTSource + + +class GPT(ABC): + def summarize(self, source:GPTSource )->str: + ''' + + :param source: + :return: + ''' + pass + def create_messages(self, segments:list,**kwargs)->list: + pass + def list_models(self): + pass \ No newline at end of file diff --git a/backend/app/gpt/deepseek_gpt.py b/backend/app/gpt/deepseek_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..1c230b8509cdf6918500e7d92326baeefee20b5d --- /dev/null +++ b/backend/app/gpt/deepseek_gpt.py @@ -0,0 +1,59 @@ +from typing import List +from app.gpt.base import GPT +from app.utils.openai_client import build_openai_client +from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT +from app.gpt.utils import fix_markdown +from app.models.gpt_model import GPTSource +from app.models.transcriber_model import TranscriptSegment +from datetime import timedelta + + +class DeepSeekGPT(GPT): + def __init__(self): + from os import getenv + self.api_key = getenv("DEEP_SEEK_API_KEY") + self.base_url = getenv("DEEP_SEEK_API_BASE_URL") + self.model=getenv('DEEP_SEEK_MODEL') + print(self.model) + self.client = build_openai_client(self.api_key, self.base_url, key_label="DeepSeek 的 API Key") + self.screenshot = False + + def _format_time(self, seconds: float) -> str: + return str(timedelta(seconds=int(seconds)))[2:] # e.g., 03:15 + + def _build_segment_text(self, segments: List[TranscriptSegment]) -> str: + return "\n".join( + f"{self._format_time(seg.start)} - {seg.text.strip()}" + for seg in segments + ) + + def ensure_segments_type(self, segments) -> List[TranscriptSegment]: + return [ + TranscriptSegment(**seg) if isinstance(seg, dict) else seg + for seg in segments + ] + + def create_messages(self, segments: List[TranscriptSegment], title: str,tags:str): + content = BASE_PROMPT.format( + video_title=title, + segment_text=self._build_segment_text(segments), + tags=tags + ) + if self.screenshot: + print(":需要截图") + content += SCREENSHOT + print(content) + return [{"role": "user", "content": content + AI_SUM}] + + def summarize(self, source: GPTSource) -> str: + self.screenshot = source.screenshot + source.segment = self.ensure_segments_type(source.segment) + messages = self.create_messages(source.segment, source.title,source.tags) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.7 + ) + return response.choices[0].message.content.strip() + + diff --git a/backend/app/gpt/gpt_factory.py b/backend/app/gpt/gpt_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..8e8b60eb8a1d3e9e06f576e9aa61daf3d9a58d3b --- /dev/null +++ b/backend/app/gpt/gpt_factory.py @@ -0,0 +1,13 @@ +from openai import OpenAI + +from app.gpt.base import GPT +from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider +from app.gpt.universal_gpt import UniversalGPT +from app.models.model_config import ModelConfig + + +class GPTFactory: + @staticmethod + def from_config(config: ModelConfig) -> GPT: + client = OpenAICompatibleProvider(api_key=config.api_key, base_url=config.base_url).get_client + return UniversalGPT(client=client, model=config.model_name) \ No newline at end of file diff --git a/backend/app/gpt/openai_gpt.py b/backend/app/gpt/openai_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..1b85af114b329c1a52e19359e3751f7cdbb3bd0f --- /dev/null +++ b/backend/app/gpt/openai_gpt.py @@ -0,0 +1,69 @@ +from typing import List +from app.gpt.base import GPT +from openai import OpenAI +from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK +from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider +from app.gpt.utils import fix_markdown +from app.models.gpt_model import GPTSource +from app.models.transcriber_model import TranscriptSegment +from datetime import timedelta + + +class OpenaiGPT(GPT): + def __init__(self): + from os import getenv + self.api_key = getenv("OPENAI_API_KEY") + self.base_url = getenv("OPENAI_API_BASE_URL") + self.model=getenv('OPENAI_MODEL') + print(self.model) + self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url) + self.screenshot = False + self.link=False + + def _format_time(self, seconds: float) -> str: + return str(timedelta(seconds=int(seconds)))[2:] # e.g., 03:15 + + def _build_segment_text(self, segments: List[TranscriptSegment]) -> str: + return "\n".join( + f"{self._format_time(seg.start)} - {seg.text.strip()}" + for seg in segments + ) + + def ensure_segments_type(self, segments) -> List[TranscriptSegment]: + return [ + TranscriptSegment(**seg) if isinstance(seg, dict) else seg + for seg in segments + ] + + def create_messages(self, segments: List[TranscriptSegment], title: str,tags:str): + content = BASE_PROMPT.format( + video_title=title, + segment_text=self._build_segment_text(segments), + tags=tags + ) + if self.screenshot: + print(":需要截图") + content += SCREENSHOT + if self.link: + print(":需要链接") + content += LINK + + print(content) + return [{"role": "user", "content": content + AI_SUM}] + def list_models(self): + return self.client.list_models() + def summarize(self, source: GPTSource) -> str: + self.screenshot = source.screenshot + self.link = source.link + source.segment = self.ensure_segments_type(source.segment) + messages = self.create_messages(source.segment, source.title,source.tags) + response = self.client.chat( + model=self.model, + messages=messages, + temperature=0.7 + ) + return response.choices[0].message.content.strip() + +if __name__ == '__main__': + gpt = OpenaiGPT() + print(gpt.list_models()) diff --git a/backend/app/gpt/prompt.py b/backend/app/gpt/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..b21b2bec6109c3a9e2574dc45ed8522560bcd589 --- /dev/null +++ b/backend/app/gpt/prompt.py @@ -0,0 +1,174 @@ +BASE_PROMPT = ''' +# 角色 + +你是一名专业的视频笔记助手(Video Note Assistant),负责将视频转录内容整理为 +结构化、高质量、可直接阅读的 Markdown 学习笔记。 + +--- + +# 语言规则 + +- 笔记必须使用 **中文** 撰写。 +- 技术术语、编程语言、框架名称、品牌名称、产品名称、人名保留 **英文** 原文, + 不要强行翻译成中文。 +- 表达保持专业、简洁、准确。 + +--- + +# 视频信息 + +视频标题: +{video_title} + +视频标签: +{tags} + +--- + +# 输入格式 + +视频分段的格式为「开始时间 - 内容」,例如: + +00:35 - React Hooks 用于管理组件状态 +01:20 - useEffect 可以处理副作用 + +视频分段如下: + +--- +{segment_text} +--- + +# 输出要求 + +- 仅返回最终的 **Markdown 笔记内容**。 +- **禁止**输出解释、前言或类似「以下是整理后的笔记」的开场白。 +- **禁止**将整个结果包裹在代码块中(例如:```` ```markdown ````,```` ``` ````)。 +- 注意区分「标题」与「列表主点」:想表达**章节标题**时不要写成 `1. **内容**` + (会被误解析为有序列表),应写成 `## 1. 内容`;而章节内部的**列表主点** + 使用有序列表 `1. **要点**` 是正确且推荐的(见下方正文层级结构)。 + +--- + +# 内容组织规则 + +围绕主题组织结构,而不是机械地按时间顺序逐句整理、逐字转录。 +目标:把内容整理成真正可阅读的学习笔记。请遵循以下原则: + +1. **完整信息**:记录尽可能多的相关细节,确保内容全面。 +2. **保留关键内容**:核心概念、关键结论、方法论、操作步骤、实战经验、 + 案例分析、最佳实践、重要事实与示例、风险提示与注意事项必须保留。 +3. **去除无价值内容**:删除寒暄、广告、口头禅、求关注求点赞 + (如「兄弟们点个关注」「感谢大家支持」)以及与主题无关的言论。 +4. **合并重复内容**:同一观点在多个时间段重复出现时,保留最完整的版本, + 合并补充信息,避免重复记录。 +5. **长视频压缩**:内容反复出现时只保留新增信息,合并相似说明, + 保留最具代表性的案例,删除重复举例。 +6. **可读布局**:必要时使用项目符号,段落简短、结构清晰、易于阅读。 + (如果「额外重要的任务」中的风格有特殊格式需求,以风格要求优先) + +--- + +# 正文层级结构 + +主内容区域统一遵循,整篇笔记保持一致: + +- 章节一律用 `## ` 二级标题,正文中不要使用 `###` 等更深的标题 +- 章节内容只有**两层**时(章节 → 要点):直接用 `- ` 无序列表,例如: + + ## 章节标题 + - 要点一 + - 要点二 + +- 章节内容有**三层及以上**时(章节 → 主要点 → 子条目):主要点用有序列表 + `1. **主要点**`(加粗),子条目用缩进三个空格的 `- `,例如: + + ## 子代理擅长的场景 + 1. **研究型任务** + - 只需要答案,不需要探索过程 + - 示例:在陌生代码库中研究认证如何工作 + 2. **代码审查** + - 审查子代理在独立上下文中运行 + - 避免主线程"记忆污染",确保客观反馈 + +--- + +# 代码规则 + +视频中出现代码示例时,必须使用 Markdown 代码块呈现: + +- 标注语言类型(如 ```` ```ts ````、```` ```python ````) +- 保留代码结构与缩进 +- 不要把代码写成普通文本 + +--- + +# 数学公式规则 + +视频中提及的数学公式必须保留,并以 LaTeX 语法呈现,适合 Markdown 渲染: + +- 行内公式:`$E = mc^2$` +- 块级公式:用 `$$` 包裹独立成段 + +禁止 `E = mc²`、`x^2+y^2` 这类非 LaTeX 的裸文本写法。 + +--- + +请始终遵循以上全部规则。 + +额外重要的任务如下(每一个都必须严格完成): + +''' + + +LINK = ''' +- **原片跳转(重要)**:为每个 `##` 主章节标题追加该段起始时间标记,格式严格为 + `*Content-[mm:ss]`(mm:ss 为两位分:两位秒),且必须「标题在前、标记在后」写在 + 同一行,例如:`## AI 的发展史 *Content-[01:23]`。 + 禁止把标记写在标题之前,禁止让标记单独成行。 +''' + +AI_SUM = ''' +- **AI 总结**:在笔记末尾追加二级标题 `## AI 总结`,用中文总结视频主题、 + 核心观点、关键结论与实践建议,长度控制在 150~300 字。 +''' + +SCREENSHOT = ''' +- **原片截图**:根据转写文本中的时间点选择截图位置。当章节涉及 UI 演示、产品操作流程、软件界面讲解、图表分析、 + 架构图说明、代码讲解、实时调试过程、前后效果对比等视觉内容时,在该章节末尾 + 插入截图标记,格式严格为 `*Screenshot-[mm:ss]`。 + 即使没有收到视频画面,也要根据转写文本中的时间点选择 2~4 个最有代表性的截图位置。 + 每个章节最多一个截图标记;没有视觉价值时不要添加,不允许滥用。 +''' + +MERGE_PROMPT = ''' +# 角色 + +你是一名 Markdown 笔记合并助手。你将收到多个来自同一视频、按时间先后排列的 +笔记片段,请合并为一份完整笔记。 + +# 合并规则 + +1. **只做合并与整理**:禁止发明新内容、补充不存在的信息、推测作者意图。 +2. **去重**:多个片段内容重复时,保留信息最完整的版本,合并补充内容, + 删除重复描述。 +3. **标题合并**:多个片段出现相同或高度相似的 `##` 章节标题时,必须合并为 + 一个章节,禁止出现重复标题。 +4. **时间顺序**:同一章节内的内容按时间标记从早到晚排列。 +5. **保留标记**:必须原样保留所有 `*Content-[mm:ss]` 与 `*Screenshot-[mm:ss]` + 标记,不要删除或改写格式;章节标题上的时间标记保持「标题在前、标记在后」 + 写在同一行。 +6. **保留 Markdown 结构**:标题层级、表格、列表、引用、代码块、LaTeX 公式 + 必须原样保留。 +7. **目录**:如果多个片段都包含 `## 目录`,只在笔记开头保留一份,并整合为 + 覆盖全部章节的完整目录;目录条目内禁止出现 `#`/`##` 等标题标记。 +8. **AI 总结**:如果多个片段都包含 `## AI 总结`,只保留一个放在笔记末尾, + 并重新整合所有片段的总结内容。 + +# 语言 + +中文输出,技术术语、品牌名称、人名保留英文。 + +# 输出 + +仅输出最终 Markdown 笔记。不要输出解释,不要将整个结果包裹在代码块中。 +''' diff --git a/backend/app/gpt/prompt_builder.py b/backend/app/gpt/prompt_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8061e9ae631b2094b0830e76de142cb19278eb72 --- /dev/null +++ b/backend/app/gpt/prompt_builder.py @@ -0,0 +1,138 @@ +from app.gpt.prompt import BASE_PROMPT + +note_formats = [ + {'label': '目录', 'value': 'toc'}, + {'label': '原片跳转', 'value': 'link'}, + {'label': '原片截图', 'value': 'screenshot'}, + {'label': 'AI总结', 'value': 'summary'} +] + +note_styles = [ + {'label': '精简', 'value': 'minimal'}, + {'label': '详细', 'value': 'detailed'}, + {'label': '学术', 'value': 'academic'}, + {"label": '教程',"value": 'tutorial', }, + {'label': '小红书', 'value': 'xiaohongshu'}, + {'label': '生活向', 'value': 'life_journal'}, + {'label': '任务导向', 'value': 'task_oriented'}, + {'label': '商业风格', 'value': 'business'}, + {'label': '会议纪要', 'value': 'meeting_minutes'} +] + + +# 生成 BASE_PROMPT 函数 +def generate_base_prompt(title, segment_text, tags, _format=None, style=None, extras=None): + # 生成 Base Prompt 开头部分 + prompt = BASE_PROMPT.format( + video_title=title, + segment_text=segment_text, + tags=tags + ) + + # 添加用户选择的格式 + if _format: + prompt += "\n" + "\n".join([get_format_function(f) for f in _format]) + + # 根据用户选择的笔记风格添加描述 + if style: + prompt += "\n" + get_style_format(style) + + # 添加额外内容 + if extras: + prompt += f"\n{extras}" + return prompt + + +# 获取格式函数 +def get_format_function(format_type): + format_map = { + 'toc': get_toc_format, + 'link': get_link_format, + 'screenshot': get_screenshot_format, + 'summary': get_summary_format + } + return format_map.get(format_type, lambda: '')() + + +# 风格描述的处理 +def get_style_format(style): + style_map = { + 'minimal': '- **精简信息**: 仅记录最重要的内容,简洁明了。', + 'detailed': '- **详细记录**: 包含完整的内容和每个部分的详细讨论。需要尽可能多的记录视频内容,最好详细的笔记', + 'academic': '- **学术风格**: 适合学术报告,正式且结构化。', + 'xiaohongshu': '''- **小红书风格**: +### 擅长使用下面的爆款关键词: +好用到哭,大数据,教科书般,小白必看,宝藏,绝绝子神器,都给我冲,划重点,笑不活了,YYDS,秘方,我不允许,压箱底,建议收藏,停止摆烂,上天在提醒你,挑战全网,手把手,揭秘,普通女生,沉浸式,有手就能做吹爆,好用哭了,搞钱必看,狠狠搞钱,打工人,吐血整理,家人们,隐藏,高级感,治愈,破防了,万万没想到,爆款,永远可以相信被夸爆手残党必备,正确姿势 + +### 采用二极管标题法创作标题: +- 正面刺激法:产品或方法+只需1秒 (短期)+便可开挂(逆天效果) +- 负面刺激法:你不XXX+绝对会后悔 (天大损失) +(紧迫感) +利用人们厌恶损失和负面偏误的心理 + +### 写作技巧 +1. 使用惊叹号、省略号等标点符号增强表达力,营造紧迫感和惊喜感。 +2. **使用emoji表情符号,来增加文字的活力** +3. 采用具有挑战性和悬念的表述,引发读、“无敌者好奇心,例如“暴涨词汇量”了”、“拒绝焦虑”等 +4. 利用正面刺激和负面激,诱发读者的本能需求和动物基本驱动力,如“离离原上谱”、“你不知道的项目其实很赚”等 +5. 融入热点话题和实用工具,提高文章的实用性和时效性,如“2023年必知”、“chatGPT狂飙进行时”等 +6. 描述具体的成果和效果,强调标题中的关键词,使其更具吸引力,例如“英语底子再差,搞清这些语法你也能拿130+” +7. 使用吸引人的标题:''', + + 'life_journal': '- **生活向**: 记录个人生活感悟,情感化表达。', + 'task_oriented': '- **任务导向**: 强调任务、目标,适合工作和待办事项。', + 'business': '- **商业风格**: 适合商业报告、会议纪要,正式且精准。', + 'meeting_minutes': '- **会议纪要**: 适合商业报告、会议纪要,正式且精准。', + 'tutorial': '- **教程笔记**: 尽可能详细的记录教程,特别是关键点和一些重要的结论步骤。' + } + return style_map.get(style, '') + + +# 格式化输出内容 +def get_toc_format(): + return ''' +- **目录**: 在笔记开头生成目录,使用以下格式(二级标题 + 无序列表,可按需嵌套子项): + + ## 目录 + + - 章节标题一 + - 章节标题二 + - 小节标题 + + 唯一的硬性要求:目录条目(含子项)内**禁止出现 `#`/`##` 等标题标记** + (即不要写成 `- ## 章节标题`,否则条目会渲染得和正文标题一样大)。 + 目录条目内不需要插入原片跳转时间标记。 + ''' + + +def get_link_format(): + return ''' +- **原片跳转(重要)**: 为每个 `##` 主章节标题追加该段起始时间标记,格式严格为 + `*Content-[mm:ss]`(mm:ss 为两位分:两位秒),且必须「标题在前、标记在后」写在同一行。 + + 正确示例:`## AI 的发展史 *Content-[01:23]` + + 禁止把标记写在标题之前,禁止让标记单独成行,禁止省略方括号或使用其他时间格式。 + ''' + + +def get_screenshot_format(): + return ''' +- **原片截图**: 请根据转写文案里的时间点,在最能帮助用户理解的位置插入截图标记, + 必须严格按照以下格式返回,否则系统无法解析: + + 格式:`*Screenshot-[mm:ss]` + + 插入规则: + - 适合插入的内容:UI 演示、产品操作流程、软件界面讲解、图表分析、架构图说明、 + 代码讲解、实时调试过程、前后效果对比 + - 即使没有收到视频画面,也要根据转写文本中的时间点选择 2~4 个最有代表性的截图位置 + - 每个章节最多一个截图标记 + - 没有视觉价值时不要添加,不允许滥用 + ''' + + +def get_summary_format(): + return ''' +- **AI 总结**: 在笔记末尾追加二级标题 `## AI 总结`,用中文总结视频主题、核心观点、 + 关键结论与实践建议,长度控制在 150~300 字。 + ''' diff --git a/backend/app/gpt/provider/OpenAI_compatible_provider.py b/backend/app/gpt/provider/OpenAI_compatible_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..5ed279ee5f21a2f8e66c4bad80773e86d8eb5e48 --- /dev/null +++ b/backend/app/gpt/provider/OpenAI_compatible_provider.py @@ -0,0 +1,41 @@ +from typing import Optional, Union + +from app.utils.logger import get_logger +from app.utils.openai_client import build_openai_client + +logging= get_logger(__name__) +class OpenAICompatibleProvider: + def __init__(self, api_key: str, base_url: str, model: Union[str, None]=None): + # build_openai_client:注入全局代理 + 校验 api_key 非空 + self.client = build_openai_client(api_key, base_url, key_label="模型供应商的 API Key") + self.model = model + + @property + def get_client(self): + return self.client + + @staticmethod + def test_connection(api_key: str, base_url: str, model: str) -> bool: + """发一条最小化 chat completion 验证 key / base_url / model 三方都通。 + + 为什么不用 client.models.list(): + - 部分代理 / 自建供应商不实现 /v1/models(如某些 OpenAI 兼容网关) + - 部分供应商 key 在没有 inference 权限时 /v1/models 仍返回 200 + 最终用户跑的就是 chat.completions.create,所以直接测它最忠实。 + max_tokens=1 + temperature=0 让请求开销 < 0.0001 美元、延迟 < 2s。 + """ + try: + client = build_openai_client( + api_key, base_url, key_label="模型供应商的 API Key", timeout=15.0, + ) + client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": "ping"}], + max_tokens=1, + temperature=0, + ) + logging.info(f"连通性测试成功(model={model})") + return True + except Exception as e: + logging.warning(f"连通性测试失败(model={model}):{e}") + return False \ No newline at end of file diff --git a/backend/app/gpt/qwen_gpt.py b/backend/app/gpt/qwen_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..c39e879ab919bef7c54a1e024146716abdc5866b --- /dev/null +++ b/backend/app/gpt/qwen_gpt.py @@ -0,0 +1,63 @@ +from typing import List +from app.gpt.base import GPT +from openai import OpenAI +from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT +from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider +from app.gpt.utils import fix_markdown +from app.models.gpt_model import GPTSource +from app.models.transcriber_model import TranscriptSegment +from datetime import timedelta + + +class QwenGPT(GPT): + def __init__(self): + from os import getenv + self.api_key = getenv("QWEN_API_KEY") + self.base_url = getenv("QWEN_API_BASE_URL") + self.model=getenv('QWEN_MODEL') + print(self.model) + self.client = OpenAICompatibleProvider(api_key=self.api_key, base_url=self.base_url) + self.screenshot = False + + def _format_time(self, seconds: float) -> str: + return str(timedelta(seconds=int(seconds)))[2:] # e.g., 03:15 + + def _build_segment_text(self, segments: List[TranscriptSegment]) -> str: + return "\n".join( + f"{self._format_time(seg.start)} - {seg.text.strip()}" + for seg in segments + ) + + def ensure_segments_type(self, segments) -> List[TranscriptSegment]: + return [ + TranscriptSegment(**seg) if isinstance(seg, dict) else seg + for seg in segments + ] + + def create_messages(self, segments: List[TranscriptSegment], title: str,tags:str): + content = BASE_PROMPT.format( + video_title=title, + segment_text=self._build_segment_text(segments), + tags=tags + ) + if self.screenshot: + print(":需要截图") + content += SCREENSHOT + print(content) + return [{"role": "user", "content": content + AI_SUM}] + def list_models(self): + return self.client.list_models() + def summarize(self, source: GPTSource) -> str: + self.screenshot = source.screenshot + source.segment = self.ensure_segments_type(source.segment) + messages = self.create_messages(source.segment, source.title,source.tags) + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.7 + ) + return response.choices[0].message.content.strip() + +if __name__ == '__main__': + gpt = QwenGPT() + print(gpt.list_models()) diff --git a/backend/app/gpt/request_chunker.py b/backend/app/gpt/request_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..7ffc74741f849df00a392ba37e4f3424de22edc5 --- /dev/null +++ b/backend/app/gpt/request_chunker.py @@ -0,0 +1,161 @@ +from dataclasses import dataclass +from typing import Callable, List, Optional + + +@dataclass +class ChunkPayload: + segments: list + image_urls: list + + +class RequestChunker: + def __init__(self, message_builder: Callable, max_bytes: int, size_estimator: Optional[Callable] = None): + self.message_builder = message_builder + self.max_bytes = max_bytes + self.size_estimator = size_estimator + + def estimate(self, messages) -> int: + if self.size_estimator: + return self.size_estimator(messages) + import json + return len(json.dumps(messages, ensure_ascii=False).encode("utf-8")) + + def _messages_size(self, segments, image_urls, **kwargs) -> int: + messages = self.message_builder(segments, image_urls, **kwargs) + return self.estimate(messages) + + def _get_text(self, segment) -> str: + if isinstance(segment, dict): + return segment.get("text", "") + return getattr(segment, "text", "") + + def _make_segment(self, segment, text: str): + if isinstance(segment, dict): + new_seg = dict(segment) + new_seg["text"] = text + return new_seg + if hasattr(segment, "__dict__"): + data = dict(segment.__dict__) + data["text"] = text + return type(segment)(**data) + return type(segment)(segment.start, segment.end, text) + + def _split_segment_to_fit(self, segment, **kwargs): + text = self._get_text(segment) + if not text: + raise ValueError("empty segment cannot be split") + lo, hi = 1, len(text) + best = None + while lo <= hi: + mid = (lo + hi) // 2 + candidate = self._make_segment(segment, text[:mid]) + size = self._messages_size([candidate], [], **kwargs) + if size <= self.max_bytes: + best = mid + lo = mid + 1 + else: + hi = mid - 1 + if best is None: + raise ValueError("single segment too large to fit request") + head = self._make_segment(segment, text[:best]) + tail = self._make_segment(segment, text[best:]) + return head, tail + + def chunk(self, segments: list, image_urls: list, **kwargs) -> List[ChunkPayload]: + segments = list(segments or []) + image_urls = list(image_urls or []) + if not segments and not image_urls: + return [] + + chunks: List[ChunkPayload] = [] + seg_idx = 0 + + while seg_idx < len(segments): + batch_segments = [] + while seg_idx < len(segments): + candidate = batch_segments + [segments[seg_idx]] + size = self._messages_size(candidate, [], **kwargs) + if size <= self.max_bytes: + batch_segments = candidate + seg_idx += 1 + continue + if not batch_segments: + head, tail = self._split_segment_to_fit(segments[seg_idx], **kwargs) + segments[seg_idx] = head + segments.insert(seg_idx + 1, tail) + continue + break + + if not batch_segments: + raise ValueError("unable to fit any content into chunk") + + chunks.append(ChunkPayload(segments=batch_segments, image_urls=[])) + + if not image_urls: + return chunks + + if not chunks: + chunks = [ChunkPayload(segments=[], image_urls=[])] + + if not segments: + for image in image_urls: + appended = False + for chunk in chunks[-1:]: + candidate_images = chunk.image_urls + [image] + if self._messages_size(chunk.segments, candidate_images, **kwargs) <= self.max_bytes: + chunk.image_urls = candidate_images + appended = True + break + + if appended: + continue + + if self._messages_size([], [image], **kwargs) > self.max_bytes: + raise ValueError("single image payload exceeds max_bytes") + chunks.append(ChunkPayload(segments=[], image_urls=[image])) + return chunks + + chunk_count = len(chunks) + total_images = len(image_urls) + for idx, image in enumerate(image_urls): + preferred_idx = min(chunk_count - 1, (idx * chunk_count) // total_images) + placed = False + + for chunk_idx in range(preferred_idx, len(chunks)): + chunk = chunks[chunk_idx] + candidate_images = chunk.image_urls + [image] + if self._messages_size(chunk.segments, candidate_images, **kwargs) <= self.max_bytes: + chunk.image_urls = candidate_images + placed = True + break + + if placed: + continue + + if self._messages_size([], [image], **kwargs) > self.max_bytes: + raise ValueError("single image payload exceeds max_bytes") + chunks.append(ChunkPayload(segments=[], image_urls=[image])) + + return chunks + + def group_texts_by_budget(self, texts: List[str], build_messages: Callable, **kwargs) -> List[List[str]]: + groups: List[List[str]] = [] + idx = 0 + while idx < len(texts): + group: List[str] = [] + while idx < len(texts): + candidate = group + [texts[idx]] + try: + messages = build_messages(candidate, [], **kwargs) + except TypeError: + messages = build_messages(candidate, **kwargs) + size = self.estimate(messages) + if size <= self.max_bytes: + group = candidate + idx += 1 + continue + if not group: + raise ValueError("single text block exceeds max_bytes") + break + groups.append(group) + return groups diff --git a/backend/app/gpt/test.py b/backend/app/gpt/test.py new file mode 100644 index 0000000000000000000000000000000000000000..dec785eb8e30db4c6290d0cffe67ccb12bb82be0 --- /dev/null +++ b/backend/app/gpt/test.py @@ -0,0 +1,17 @@ +from app.models.model_config import ModelConfig + +if __name__ == '__main__': + from app.gpt.gpt_factory import GPTFactory + # 构建模型config + config=ModelConfig( + id='asas', + api_key='', + base_url='', + model_name="gpt-4o", + provider='openai', + name='gpt-4o' + ) + # 构建GPT + gpt=GPTFactory().from_config(config) + + diff --git a/backend/app/gpt/tools.py b/backend/app/gpt/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/gpt/universal_gpt.py b/backend/app/gpt/universal_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4982ad6d0441508d3f4712f50df8e87ec0d77c --- /dev/null +++ b/backend/app/gpt/universal_gpt.py @@ -0,0 +1,358 @@ +from app.gpt.base import GPT +from app.gpt.prompt_builder import generate_base_prompt +from app.models.gpt_model import GPTSource +import os +import hashlib +import json +import time +from datetime import datetime, timezone +from pathlib import Path + +from app.gpt.prompt import BASE_PROMPT, AI_SUM, SCREENSHOT, LINK, MERGE_PROMPT +from app.gpt.utils import fix_markdown, strip_think_blocks +from app.gpt.request_chunker import RequestChunker +from app.models.transcriber_model import TranscriptSegment +from datetime import timedelta +from typing import List + + +class UniversalGPT(GPT): + def __init__(self, client, model: str, temperature: float = 0.7): + self.client = client + self.model = model + self.temperature = temperature + self.screenshot = False + self.link = False + # 本次 summarize 累计的 token 用量(跨分块/合并多次调用求和) + self.total_tokens = 0 + self.max_request_bytes = int(os.getenv("OPENAI_MAX_REQUEST_BYTES", str(45 * 1024 * 1024))) + self.checkpoint_dir = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results")) + self.checkpoint_dir.mkdir(parents=True, exist_ok=True) + # 初始化时缓存重试配置,避免每次请求重复读取环境变量 + self._max_retry_attempts = max(1, int(os.getenv("OPENAI_RETRY_ATTEMPTS", "3"))) + self._retry_base_backoff = float(os.getenv("OPENAI_RETRY_BACKOFF_SECONDS", "1.5")) + + def _format_time(self, seconds: float) -> str: + return str(timedelta(seconds=int(seconds)))[2:] + + def _build_segment_text(self, segments: List[TranscriptSegment]) -> str: + return "\n".join( + f"{self._format_time(seg.start)} - {seg.text.strip()}" + for seg in segments + ) + + def ensure_segments_type(self, segments) -> List[TranscriptSegment]: + return [TranscriptSegment(**seg) if isinstance(seg, dict) else seg for seg in segments] + + def create_messages(self, segments: List[TranscriptSegment], **kwargs): + + content_text = generate_base_prompt( + title=kwargs.get('title'), + segment_text=self._build_segment_text(segments), + tags=kwargs.get('tags'), + _format=kwargs.get('_format'), + style=kwargs.get('style'), + extras=kwargs.get('extras'), + ) + + video_img_urls = kwargs.get('video_img_urls', []) + + content: list[dict] | str + if video_img_urls: + # 有截图时走 OpenAI 多模态 content 数组(text + image_url)。 + # 不要带 "detail" 字段:OpenAI 缺省即 auto,而 MiniMax 等兼容接口 + # 会对 detail:"auto" 报 400 invalid image detail (2013),导致带图请求全挂。 + content = [{"type": "text", "text": content_text}] + for url in video_img_urls: + content.append({ + "type": "image_url", + "image_url": { + "url": url + } + }) + else: + # 纯文本场景退回 string content:DeepSeek deepseek-chat 等非多模态模型 + # 不识别 [{"type":"text",...}] 数组形态,会返回 invalid_request_error + # (issue #282)。OpenAI 规范本身也允许 content 为 string。 + content = content_text + + messages = [{ + "role": "user", + "content": content + }] + + return messages + + def list_models(self): + return self.client.models.list() + + def _estimate_messages_bytes(self, messages: list) -> int: + import json + return len(json.dumps(messages, ensure_ascii=False).encode("utf-8")) + + def _build_merge_messages(self, partials: list) -> list: + merge_text = MERGE_PROMPT + "\n\n" + "\n\n---\n\n".join(partials) + # 合并阶段没有图片,直接用 string content 兼容非多模态模型(issue #282) + return [{ + "role": "user", + "content": merge_text + }] + + def _checkpoint_path(self, checkpoint_key: str) -> Path: + safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in checkpoint_key) + return self.checkpoint_dir / f"{safe_key}.gpt.checkpoint.json" + + def _build_source_signature(self, source: GPTSource) -> str: + payload = { + "model": self.model, + "temperature": self.temperature, + "max_request_bytes": self.max_request_bytes, + "title": source.title, + "tags": source.tags, + "format": source._format, + "style": source.style, + "extras": source.extras, + "video_img_urls": source.video_img_urls or [], + "segments": [ + { + "start": getattr(seg, "start", None), + "end": getattr(seg, "end", None), + "text": getattr(seg, "text", "") + } + for seg in source.segment + ], + } + raw = json.dumps(payload, ensure_ascii=False, sort_keys=True) + return hashlib.sha256(raw.encode("utf-8")).hexdigest() + + def _load_checkpoint(self, checkpoint_key: str, source_signature: str) -> dict | None: + path = self._checkpoint_path(checkpoint_key) + if not path.exists(): + return None + try: + data = json.loads(path.read_text(encoding="utf-8")) + if data.get("source_signature") != source_signature: + path.unlink(missing_ok=True) + return None + return data + except Exception: + path.unlink(missing_ok=True) + return None + + def _save_checkpoint(self, checkpoint_key: str, source_signature: str, partials: list, phase: str) -> None: + path = self._checkpoint_path(checkpoint_key) + data = { + "version": 1, + "source_signature": source_signature, + "phase": phase, + "partials": partials, + "updated_at": datetime.now(timezone.utc).isoformat(), + } + tmp_path = path.with_suffix(".tmp") + tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + tmp_path.replace(path) + + def _clear_checkpoint(self, checkpoint_key: str) -> None: + self._checkpoint_path(checkpoint_key).unlink(missing_ok=True) + + @staticmethod + def _is_insufficient_quota_error(exc: Exception) -> bool: + raw = str(exc) + return ( + "insufficient_user_quota" in raw + or "预扣费额度失败" in raw + or "insufficient quota" in raw.lower() + ) + + @staticmethod + def _is_retryable_error(exc: Exception) -> bool: + raw = str(exc).lower() + retryable_tokens = ( + "error code: 524", + "bad_response_status_code", + "timed out", + "timeout", + "rate limit", + "error code: 429", + "error code: 500", + "error code: 502", + "error code: 503", + "error code: 504", + "apiconnectionerror", + "connection error", + "service unavailable", + ) + if any(token in raw for token in retryable_tokens): + return True + + status = getattr(exc, "status_code", None) or getattr(exc, "status", None) + return status in {408, 409, 429, 500, 502, 503, 504, 524} + + @staticmethod + def _is_temperature_unsupported_error(exc: Exception) -> bool: + """OpenAI o1/o3/gpt-5 系列等新模型不接受自定义 temperature, + 只允许默认值 1,传 0.7 会报 `'temperature' does not support 0.7 ...`。""" + raw = str(exc).lower() + return "temperature" in raw and ( + "does not support" in raw + or "unsupported_value" in raw + or "only the default" in raw + ) + + def _do_create(self, messages: list): + """单次调用。如果模型拒绝自定义 temperature,就地去掉该参数再试一次 + (不消耗外层的重试次数预算),仍失败则把异常抛给外层重试逻辑。""" + try: + return self.client.chat.completions.create( + model=self.model, + messages=messages, + temperature=self.temperature, + ) + except Exception as exc: + if self._is_temperature_unsupported_error(exc): + print(f"[universal_gpt] 模型 {self.model} 不支持自定义 temperature,改用默认值重试") + return self.client.chat.completions.create( + model=self.model, + messages=messages, + ) + raise + + def _accumulate_usage(self, response) -> None: + """累加单次响应的 token 用量。部分供应商可能不返回 usage,容错跳过。""" + try: + usage = getattr(response, "usage", None) + total = getattr(usage, "total_tokens", None) if usage else None + if total: + self.total_tokens += int(total) + except Exception: + pass + + def _chat_completion_create(self, messages: list): + last_exc = None + for attempt in range(self._max_retry_attempts): + try: + response = self._do_create(messages) + self._accumulate_usage(response) + return response + except Exception as exc: + last_exc = exc + if attempt == self._max_retry_attempts - 1 or not self._is_retryable_error(exc): + raise + sleep_seconds = self._retry_base_backoff * (2 ** attempt) + time.sleep(sleep_seconds) + + if last_exc is not None: + raise last_exc + raise RuntimeError("chat completion failed without exception") + + def _merge_partials(self, partials: list, checkpoint_key: str | None, source_signature: str | None) -> str: + def build_messages(texts, *_args, **_kwargs): + return self._build_merge_messages(texts) + + merge_chunker = RequestChunker( + lambda *_args, **_kwargs: [], + self.max_request_bytes, + self._estimate_messages_bytes + ) + + current_partials = list(partials) + if not current_partials: + # 上游转写为空/分块为零时的兜底:给可读错误,而不是 current_partials[0] 的 IndexError + raise ValueError("没有可总结的内容:转写结果为空或分块失败,请检查转写设置后重试。") + while len(current_partials) > 1: + groups = merge_chunker.group_texts_by_budget(current_partials, build_messages) + new_partials = [] + for group_idx, group in enumerate(groups): + messages = build_messages(group) + try: + response = self._chat_completion_create(messages) + except Exception as exc: + if checkpoint_key and source_signature: + self._save_checkpoint(checkpoint_key, source_signature, current_partials, "merge") + raise + + new_partials.append(strip_think_blocks(response.choices[0].message.content)) + + if checkpoint_key and source_signature: + remaining_partials = [] + for remaining_group in groups[group_idx + 1:]: + remaining_partials.extend(remaining_group) + resumable_partials = new_partials + remaining_partials + self._save_checkpoint(checkpoint_key, source_signature, resumable_partials, "merge") + + current_partials = new_partials + + return current_partials[0] + + def summarize(self, source: GPTSource) -> str: + self.total_tokens = 0 + self.screenshot = source.screenshot + self.link = source.link + source.segment = self.ensure_segments_type(source.segment) + checkpoint_key = source.checkpoint_key + source_signature = self._build_source_signature(source) if checkpoint_key else None + + def message_builder(segments, image_urls, **kwargs): + return self.create_messages(segments, video_img_urls=image_urls, **kwargs) + + chunker = RequestChunker(message_builder, self.max_request_bytes, self._estimate_messages_bytes) + + try: + chunks = chunker.chunk( + source.segment, + source.video_img_urls or [], + title=source.title, + tags=source.tags, + _format=source._format, + style=source.style, + extras=source.extras + ) + except ValueError: + chunks = chunker.chunk( + source.segment, + [], + title=source.title, + tags=source.tags, + _format=source._format, + style=source.style, + extras=source.extras + ) + + partials = [] + if checkpoint_key and source_signature: + checkpoint = self._load_checkpoint(checkpoint_key, source_signature) + if checkpoint and isinstance(checkpoint.get("partials"), list): + partials = checkpoint["partials"] + + if len(partials) > len(chunks): + partials = [] + + for chunk in chunks[len(partials):]: + messages = self.create_messages( + chunk.segments, + title=source.title, + tags=source.tags, + video_img_urls=chunk.image_urls, + _format=source._format, + style=source.style, + extras=source.extras + ) + try: + response = self._chat_completion_create(messages) + except Exception as exc: + if checkpoint_key and source_signature: + self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize") + raise + + partials.append(strip_think_blocks(response.choices[0].message.content)) + if checkpoint_key and source_signature: + self._save_checkpoint(checkpoint_key, source_signature, partials, "summarize") + + if len(partials) == 1: + if checkpoint_key: + self._clear_checkpoint(checkpoint_key) + return partials[0] + merged = self._merge_partials(partials, checkpoint_key, source_signature) + if checkpoint_key: + self._clear_checkpoint(checkpoint_key) + return merged diff --git a/backend/app/gpt/utils.py b/backend/app/gpt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3a7c1193d9c64b82b900b74c4a552c72294b1aa1 --- /dev/null +++ b/backend/app/gpt/utils.py @@ -0,0 +1,36 @@ +import codecs +import re + + +def fix_markdown(markdown: str) -> str: + return codecs.decode(markdown, 'unicode_escape') + + +# 推理模型(DeepSeek R1、QwQ 等)会把思考过程放进 ... 标签返回; +# 部分供应商/网关还会吞掉起始标签只留下 ,或在输出截断时只有起始标签。 +# 笔记/问答正文落地前统一剥掉,避免思考过程混进用户可见内容。 +_THINK_PAIRED_RE = re.compile(r'.*?', re.IGNORECASE | re.DOTALL) +_THINK_ORPHAN_CLOSE_RE = re.compile(r'', re.IGNORECASE) +_THINK_UNCLOSED_RE = re.compile(r'.*\Z', re.IGNORECASE | re.DOTALL) + + +def strip_think_blocks(text: str | None) -> str: + """剥离模型输出中的思考过程标签,返回干净正文。 + + 覆盖三种形态: + - 成对标签:...(含多段、跨行、 变体,大小写不敏感) + - 只剩孤立 (起始标签被供应商吃掉):取最后一个闭合标签之后的内容 + - 只有 没闭合(输出被截断):丢弃标签起的全部内容 + """ + if not text: + return "" + cleaned = _THINK_PAIRED_RE.sub('', text) + + last_close = None + for last_close in _THINK_ORPHAN_CLOSE_RE.finditer(cleaned): + pass + if last_close: + cleaned = cleaned[last_close.end():] + + cleaned = _THINK_UNCLOSED_RE.sub('', cleaned) + return cleaned.strip() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/models/audio_model.py b/backend/app/models/audio_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d2808b56473100a09dece3262eb4b88fdf9e7c8b --- /dev/null +++ b/backend/app/models/audio_model.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class AudioDownloadResult: + file_path: str # 本地音频路径 + title: str # 视频标题 + duration: float # 视频时长(秒) + cover_url: Optional[str] # 视频封面图 + platform: str # 平台,如 "bilibili" + video_id: str # 唯一视频ID + raw_info: dict # yt-dlp 的原始 info 字典 + video_path: Optional[str] = None # 新增字段:可选视频文件路径 + diff --git a/backend/app/models/gpt_model.py b/backend/app/models/gpt_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a5fc936e2b1e8d4af9a94e1e7dff808dce5699ea --- /dev/null +++ b/backend/app/models/gpt_model.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import List, Union, Optional + +from app.models.transcriber_model import TranscriptSegment + + +@dataclass +class GPTSource: + segment: Union[List[TranscriptSegment], List] + title: str + tags:str + screenshot: Optional[bool] = False + link: Optional[bool] = False + style: Optional[str] = None + extras: Optional[str] = None + _format: Optional[list] = None + video_img_urls: Optional[list] = None + checkpoint_key: Optional[str] = None + diff --git a/backend/app/models/model_config.py b/backend/app/models/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1d466d16d1b9ec6f0bb41e57d00d27e0855c290b --- /dev/null +++ b/backend/app/models/model_config.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class ModelConfig: + """ + 存储每个模型提供商的调用参数信息,用于从数据库读取并动态构建 GPT 调用实例。 + """ + name: str # 展示名,如 "GPT-4 Turbo"(用于前端展示) + provider: str # 模型提供商,如 "openai"、"qwen"、"deepseek" + api_key: str # 调用该模型使用的 API Key + base_url: str # 模型 API 接口地址(OpenAI SDK兼容) + model_name: str # 实际请求用的模型名称,如 "gpt-4-turbo" + created_at: Optional[datetime] = None # 可选:创建时间(从 SQLite 自动生成) \ No newline at end of file diff --git a/backend/app/models/notes_model.py b/backend/app/models/notes_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6bbcad951261b67006c8db3f6683a4dbcd50f408 --- /dev/null +++ b/backend/app/models/notes_model.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Optional + +from app.models.audio_model import AudioDownloadResult +from app.models.transcriber_model import TranscriptResult + + +@dataclass +class NoteResult: + markdown: str # GPT 总结的 Markdown 内容 + transcript: TranscriptResult # Whisper 转写结果 + audio_meta: AudioDownloadResult # 音频下载的元信息(title、duration、封面等) + total_tokens: int = 0 # 本次生成消耗的 LLM token 总量(0 表示供应商未返回) \ No newline at end of file diff --git a/backend/app/models/provide_model.py b/backend/app/models/provide_model.py new file mode 100644 index 0000000000000000000000000000000000000000..01009251843e255ddc83be3e75fde37cf97e66a0 --- /dev/null +++ b/backend/app/models/provide_model.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + + +@dataclass +class ProviderModel: + """ + 存储每个模型提供商的调用参数信息,用于从数据库读取并动态构建 GPT 调用实例。 + """ + id: str # 模型唯一 ID(推荐用 UUID) + logo: str # 模型图标 URL + name: str # 展示名,如 "GPT-4 Turbo"(用于前端展示) + api_key: str # 调用该模型使用的 API Key + base_url: str # 模型 API 接口地址(OpenAI SDK兼容) + created_at: Optional[datetime] = None # 可选:创建时间(从 SQLite 自动生成) \ No newline at end of file diff --git a/backend/app/models/transcriber_model.py b/backend/app/models/transcriber_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4bd246d34464d011f48b5f63907510a892003b --- /dev/null +++ b/backend/app/models/transcriber_model.py @@ -0,0 +1,16 @@ + +from dataclasses import dataclass +from typing import List, Optional + +@dataclass +class TranscriptSegment: + start: float # 开始时间(秒) + end: float # 结束时间(秒) + text: str # 该段文字 + +@dataclass +class TranscriptResult: + language: Optional[str] # 检测语言(如 "zh"、"en") + full_text: str # 完整合并后的文本(用于摘要) + segments: List[TranscriptSegment] # 分段结构,适合前端显示时间轴字幕等 + raw: Optional[dict] = None # 原始响应数据,便于调试或平台特性处理 \ No newline at end of file diff --git a/backend/app/models/video_record.py b/backend/app/models/video_record.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/routers/__init__.py b/backend/app/routers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/routers/article.py b/backend/app/routers/article.py new file mode 100644 index 0000000000000000000000000000000000000000..7c92434795d264b2ae217134feb4c1a767916489 --- /dev/null +++ b/backend/app/routers/article.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from typing import Optional + +from fastapi import APIRouter +from pydantic import BaseModel + +from app.article_fetchers.base import ArticleFetchError +from app.exceptions.biz_exception import BizException +from app.services.article import ArticleService +from app.utils.response import ResponseWrapper as R + +router = APIRouter() + + +class ArticleGenerateRequest(BaseModel): + url: str + platform: str + provider_id: str + model_name: str + style: str = "" + extras: str = "" + task_id: Optional[str] = None + + +class ArticleImportRequest(BaseModel): + url: str = "" + platform: str = "generic_web" + title: str = "" + content_text: str + author_name: str = "" + provider_id: str + model_name: str + style: str = "" + extras: str = "" + task_id: Optional[str] = None + + +class SubscriptionRequest(BaseModel): + platform: str + type: str + query: str + label: str = "" + + +class SummarizeItemRequest(BaseModel): + provider_id: str + model_name: str + style: str = "" + extras: str = "" + + +@router.post("/articles/generate") +def generate_article(data: ArticleGenerateRequest): + try: + return R.success( + ArticleService().generate_from_url( + url=data.url, + platform=data.platform, + provider_id=data.provider_id, + model_name=data.model_name, + style=data.style, + extras=data.extras, + task_id=data.task_id, + ) + ) + except (ArticleFetchError, ValueError) as exc: + raise BizException(400400, str(exc)) from exc + + +@router.post("/articles/import_content") +def import_article_content(data: ArticleImportRequest): + try: + return R.success( + ArticleService().generate_from_content( + url=data.url, + platform=data.platform, + title=data.title, + content_text=data.content_text, + author_name=data.author_name, + provider_id=data.provider_id, + model_name=data.model_name, + style=data.style, + extras=data.extras, + task_id=data.task_id, + ) + ) + except (ArticleFetchError, ValueError) as exc: + raise BizException(400400, str(exc)) from exc + + +@router.get("/articles/search") +def search_articles(platform: str, keyword: str, limit: int = 20): + try: + return R.success(ArticleService().search(platform=platform, keyword=keyword, limit=limit)) + except (ArticleFetchError, ValueError) as exc: + raise BizException(400400, str(exc)) from exc + + +@router.post("/article_subscriptions") +def create_article_subscription(data: SubscriptionRequest): + return R.success( + ArticleService().create_subscription( + platform=data.platform, + subscription_type=data.type, + query=data.query, + label=data.label, + ) + ) + + +@router.get("/article_subscriptions") +def get_article_subscriptions(): + return R.success(ArticleService().list_subscriptions()) + + +@router.post("/article_subscriptions/{subscription_id}/refresh") +def refresh_article_subscription(subscription_id: int, limit: int = 20): + try: + return R.success(ArticleService().refresh_subscription(subscription_id, limit=limit)) + except (ArticleFetchError, ValueError) as exc: + raise BizException(400400, str(exc)) from exc + + +@router.get("/article_items") +def get_article_items(subscription_id: Optional[int] = None): + return R.success(ArticleService().list_items(subscription_id=subscription_id)) + + +@router.get("/article_items/{item_id}") +def get_article_item(item_id: int): + try: + return R.success(ArticleService().get_item(item_id)) + except ValueError as exc: + raise BizException(400404, str(exc)) from exc + + +@router.post("/article_items/{item_id}/summarize") +def summarize_article_item(item_id: int, data: SummarizeItemRequest): + try: + return R.success( + ArticleService().summarize_item( + item_id=item_id, + provider_id=data.provider_id, + model_name=data.model_name, + style=data.style, + extras=data.extras, + ) + ) + except (ArticleFetchError, ValueError) as exc: + raise BizException(400400, str(exc)) from exc diff --git a/backend/app/routers/chat.py b/backend/app/routers/chat.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f79505823fda052d17974145fc69cc56e748a8 --- /dev/null +++ b/backend/app/routers/chat.py @@ -0,0 +1,184 @@ +from typing import Optional + +from fastapi import APIRouter, BackgroundTasks +from pydantic import BaseModel + +from app.services.chat_service import chat as chat_service, chat_across as chat_across_service +from app.services.vector_store import VectorStoreManager +from app.utils.logger import get_logger +from app.utils.response import ResponseWrapper as R + +logger = get_logger(__name__) + +router = APIRouter() + +# 索引状态追踪: task_id -> "indexing" | "indexed" | "failed" +_index_status: dict[str, str] = {} + + +class IndexRequest(BaseModel): + task_id: str + + +class ChatMessage(BaseModel): + role: str + content: str + + +class AskRequest(BaseModel): + task_id: str + question: str + history: list[ChatMessage] = [] + provider_id: str + model_name: str + + +def _do_index(task_id: str): + """后台执行索引任务。""" + try: + _index_status[task_id] = "indexing" + store = VectorStoreManager() + store.index_task(task_id) + _index_status[task_id] = "indexed" + logger.info(f"索引完成: {task_id}") + except Exception as e: + _index_status[task_id] = "failed" + logger.error(f"索引失败: {task_id}, {e}") + + +@router.post("/chat/index") +def index_task(data: IndexRequest, background_tasks: BackgroundTasks): + """触发后台索引,立即返回。""" + if _index_status.get(data.task_id) == "indexing": + return R.success(msg="正在索引中") + + # 如果已经索引过,直接返回 + store = VectorStoreManager() + if store.is_indexed(data.task_id): + _index_status[data.task_id] = "indexed" + return R.success(msg="已完成索引") + + _index_status[data.task_id] = "indexing" + background_tasks.add_task(_do_index, data.task_id) + return R.success(msg="开始索引") + + +@router.get("/chat/status") +def chat_status(task_id: str): + """返回索引状态:idle / indexing / indexed / failed。""" + try: + # 优先检查内存状态 + status = _index_status.get(task_id) + if status: + return R.success(data={"status": status, "indexed": status == "indexed"}) + + # 内存没有记录,检查持久化 + store = VectorStoreManager() + indexed = store.is_indexed(task_id) + if indexed: + _index_status[task_id] = "indexed" + return R.success(data={"status": "indexed" if indexed else "idle", "indexed": indexed}) + except Exception as e: + logger.error(f"查询索引状态失败: {e}") + return R.success(data={"status": "idle", "indexed": False}) + + +@router.post("/chat/ask") +def ask_question(data: AskRequest): + """基于笔记内容的 RAG 问答。""" + try: + history = [{"role": m.role, "content": m.content} for m in data.history] + result = chat_service( + task_id=data.task_id, + question=data.question, + history=history, + provider_id=data.provider_id, + model_name=data.model_name, + ) + return R.success(data=result) + except ValueError as e: + return R.error(msg=str(e)) + except Exception as e: + logger.error(f"Chat 问答失败: {e}", exc_info=True) + return R.error(msg=f"问答失败: {str(e)}") + + +# ── 跨笔记知识库问答 ───────────────────────────────────────── + + +class AskAcrossScope(BaseModel): + """ + 跨笔记检索的过滤条件。 + - task_ids=None → 全库 + - task_ids=[] → 没匹配到任何笔记(合集筛选后为空时使用,由前端解析) + """ + task_ids: Optional[list[str]] = None + + +class AskAcrossRequest(BaseModel): + question: str + history: list[ChatMessage] = [] + scope: AskAcrossScope = AskAcrossScope() + provider_id: str + model_name: str + + +@router.post("/chat/ask_across") +def ask_across(data: AskAcrossRequest): + """跨多篇笔记的知识库问答。前端把合集/平台/时间过滤解析成 task_ids 列表传入。""" + try: + history = [{"role": m.role, "content": m.content} for m in data.history] + result = chat_across_service( + question=data.question, + history=history, + scope={"task_ids": data.scope.task_ids}, + provider_id=data.provider_id, + model_name=data.model_name, + ) + return R.success(data=result) + except ValueError as e: + return R.error(msg=str(e)) + except Exception as e: + logger.error(f"跨笔记问答失败: {e}", exc_info=True) + return R.error(msg=f"问答失败: {str(e)}") + + +@router.get("/chat/indexed_tasks") +def list_indexed_tasks(): + """返回所有已建立向量索引的 task_id,供前端「重建/统计」用。""" + try: + store = VectorStoreManager() + return R.success(data={"task_ids": store.list_indexed_task_ids()}) + except Exception as e: + logger.error(f"列出索引失败: {e}") + return R.error(msg=str(e)) + + +def _do_reindex_all(task_ids: list[str]): + """后台批量重建索引。""" + store = VectorStoreManager() + for tid in task_ids: + try: + store.index_task(tid) + _index_status[tid] = "indexed" + except Exception as e: + _index_status[tid] = "failed" + logger.error(f"重建索引失败 task_id={tid}: {e}") + logger.info(f"批量重建索引完成,共 {len(task_ids)} 个") + + +class ReindexAllRequest(BaseModel): + task_ids: Optional[list[str]] = None # None = 重建所有已索引的 + + +@router.post("/chat/reindex_all") +def reindex_all(data: ReindexAllRequest, background_tasks: BackgroundTasks): + """后台批量重建索引(兜底用,不阻塞返回)。task_ids=None 时重建所有已索引的笔记。""" + store = VectorStoreManager() + targets = data.task_ids if data.task_ids is not None else store.list_indexed_task_ids() + if not targets: + return R.success(msg="没有需要重建的索引", data={"count": 0}) + for tid in targets: + _index_status[tid] = "indexing" + background_tasks.add_task(_do_reindex_all, targets) + return R.success(msg=f"已开始后台重建 {len(targets)} 个索引", data={"count": len(targets)}) diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py new file mode 100644 index 0000000000000000000000000000000000000000..a39cc7ca5b2ddd86f495d807beeb96c4579a901a --- /dev/null +++ b/backend/app/routers/config.py @@ -0,0 +1,647 @@ +import os +import platform +from pathlib import Path + +from fastapi import APIRouter, HTTPException, BackgroundTasks +from pydantic import BaseModel +from typing import Optional +from app.utils.response import ResponseWrapper as R +from app.utils.logger import get_logger +from app.utils.path_helper import get_model_dir + +from app.services.cookie_manager import CookieConfigManager +from app.services.transcriber_config_manager import TranscriberConfigManager +from ffmpeg_helper import ensure_ffmpeg_or_raise + +logger = get_logger(__name__) + +router = APIRouter() +cookie_manager = CookieConfigManager() +transcriber_config_manager = TranscriberConfigManager() + + +class CookieUpdateRequest(BaseModel): + platform: str + cookie: str + # 可选:「从浏览器读取 cookie」配置;为空字符串表示清除该设置。 + # 支持的值参见 yt-dlp 文档(chrome/firefox/safari/edge/brave/chromium/opera/vivaldi/whale)。 + browser: Optional[str] = None + + +class BrowserCookieSyncRequest(BaseModel): + platform: str + browser: str + + +@router.get("/get_downloader_cookie/{platform}") +def get_cookie(platform: str): + cookie = cookie_manager.get(platform) or "" + browser = cookie_manager.get_browser(platform) or "" + if not cookie and not browser: + return R.success(msg='未找到Cookies', data={"platform": platform, "cookie": "", "browser": ""}) + return R.success( + data={"platform": platform, "cookie": cookie, "browser": browser} + ) + + +class CustomPlatformRequest(BaseModel): + key: str + name: str + match: str + + +@router.get("/custom_platforms") +def list_custom_platforms(): + from app.services import custom_platform_manager + return R.success(data=custom_platform_manager.list_all()) + + +@router.post("/custom_platforms") +def upsert_custom_platform(data: CustomPlatformRequest): + from app.services import custom_platform_manager + try: + item = custom_platform_manager.upsert(data.key, data.name, data.match) + return R.success(data=item) + except ValueError as e: + return R.error(msg=str(e)) + + +@router.delete("/custom_platforms/{key}") +def delete_custom_platform(key: str): + from app.services import custom_platform_manager + from app.services.cookie_manager import CookieConfigManager + if custom_platform_manager.delete(key): + # 顺便清掉关联的 cookie 记录,保持配置文件整洁 + CookieConfigManager().delete(key) + return R.success(msg="已删除") + return R.error(msg="未找到该自定义平台") + + +@router.post("/update_downloader_cookie") +def update_cookie(data: CookieUpdateRequest): + cookie = (data.cookie or "").strip() + browser = (data.browser or "").strip() if data.browser is not None else None + # 两者都空 → 视为清除整条配置,保持 config 文件整洁 + if not cookie and (browser == "" or browser is None): + cookie_manager.delete(data.platform) + else: + cookie_manager.set(data.platform, cookie, browser=browser if browser is not None else None) + return R.success() + + +@router.post("/sync_downloader_cookie_from_browser") +def sync_cookie_from_browser(data: BrowserCookieSyncRequest): + from app.services.browser_cookie import BrowserCookieError, sync_browser_cookie + + try: + result = sync_browser_cookie(data.platform, data.browser, manager=cookie_manager) + return R.success(data=result, msg=f"已从浏览器读取 {result['count']} 条 Cookie") + except BrowserCookieError as exc: + return R.error(msg=str(exc)) + + +class TranscriberConfigRequest(BaseModel): + transcriber_type: str + whisper_model_size: Optional[str] = None + whisper_custom_model: Optional[str] = None + funasr_model: Optional[str] = None + + +AVAILABLE_TRANSCRIBER_TYPES = [ + {"value": "fast-whisper", "label": "Faster Whisper(本地)"}, + {"value": "bcut", "label": "必剪(在线)"}, + {"value": "kuaishou", "label": "快手(在线)"}, + {"value": "groq", "label": "Groq(在线)"}, + {"value": "mlx-whisper", "label": "MLX Whisper(仅macOS)"}, + {"value": "funasr", "label": "FunASR(阿里·中文,需装依赖)"}, +] + +# "custom" 末项:用户自定义本地/HF whisper 模型(路径见 whisper_custom_model) +WHISPER_MODEL_SIZES = ["tiny", "base", "small", "medium", "large-v3", "large-v3-turbo", "custom"] + + +@router.get("/transcriber_config") +def get_transcriber_config(): + import sys + from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE, FUNASR_AVAILABLE + + config = transcriber_config_manager.get_config() + + # mlx_whisper 不可用时给前端精确的安装指引: + # - 桌面端(冻结):装到插件目录(main.py 启动时已加进 sys.path),必须用 Python 3.11 + # - 源码/Docker:直接装进后端环境 + if getattr(sys, "frozen", False): + from app.utils.path_helper import get_plugin_packages_dir + plugin_dir = get_plugin_packages_dir() + mlx_install_command = f'python3.11 -m pip install --target "{plugin_dir}" mlx_whisper' + mlx_install_note = ( + "桌面版应用内置 Python 3.11,必须用同版本 Python 安装(macOS 可先 " + "brew install python@3.11)。安装完成后重启应用生效。" + ) + else: + plugin_dir = "" + mlx_install_command = "pip install mlx_whisper" + mlx_install_note = "安装到后端运行环境(venv)后重启后端生效。" + + return R.success(data={ + **config, + "available_types": AVAILABLE_TRANSCRIBER_TYPES, + "whisper_model_sizes": WHISPER_MODEL_SIZES, + "mlx_whisper_available": MLX_WHISPER_AVAILABLE, + "mlx_install_command": mlx_install_command, + "mlx_install_note": mlx_install_note, + "mlx_plugin_dir": plugin_dir, + # FunASR 可选引擎:未安装时前端给安装指引并禁用保存。 + # 桌面冻结包不支持(torch 与 PyInstaller 运行时不兼容,装进插件目录会让应用无法启动), + # 此时不下发安装命令,只给说明。 + "funasr_available": FUNASR_AVAILABLE, + "funasr_install_command": "" if getattr(sys, "frozen", False) else "pip install funasr torch torchaudio", + "funasr_install_note": ( + "桌面版暂不支持 FunASR:其依赖的 PyTorch 与桌面打包运行时不兼容," + "强行安装到插件目录会导致应用无法启动。如需 FunASR 请使用源码或 Docker 部署。" + if getattr(sys, "frozen", False) + else "FunASR 依赖 PyTorch(约 2GB),属可选引擎,安装到后端运行环境(venv)后重启生效;" + "中文识别效果通常优于 Whisper,模型首次使用经 modelscope 自动下载。" + ), + }) + + +@router.post("/transcriber_config") +def update_transcriber_config(data: TranscriberConfigRequest): + config = transcriber_config_manager.update_config( + transcriber_type=data.transcriber_type, + whisper_model_size=data.whisper_model_size, + whisper_custom_model=data.whisper_custom_model, + funasr_model=data.funasr_model, + ) + return R.success(data=config) + + +# ---- 全局代理配置(作用于 LLM API + 转写 API + yt-dlp 下载)---- + +class ProxyConfigRequest(BaseModel): + enabled: bool + url: Optional[str] = None + + +@router.get("/proxy_config") +def get_proxy_config(): + from app.services.proxy_config_manager import ProxyConfigManager + mgr = ProxyConfigManager() + cfg = mgr.get_config() + # effective 给前端展示「当前实际生效的代理」——可能来自配置,也可能来自 env 兜底 + return R.success(data={ + **cfg, + "effective": mgr.get_proxy_url() or "", + }) + + +@router.post("/proxy_config") +def update_proxy_config(data: ProxyConfigRequest): + from app.services.proxy_config_manager import ProxyConfigManager + mgr = ProxyConfigManager() + cfg = mgr.update_config(enabled=data.enabled, url=data.url) + return R.success(data={ + **cfg, + "effective": mgr.get_proxy_url() or "", + }) + + +# ---- Whisper 模型下载状态 & 下载触发 ---- + +# 用于跟踪正在进行的下载任务 +_downloading: dict[str, str] = {} # model_size -> status ("downloading" | "done" | "failed") + + +def _check_whisper_model_exists(model_size: str, subdir: str = "whisper") -> bool: + """检查指定 whisper 模型是否已下载完整到本地。 + + faster-whisper 把模型缓存在 HF cache 布局下: + /models--Systran--faster-whisper-{size}/snapshots//model.bin + 必须能在某个 snapshot 目录里找到 model.bin 才算完成。 + (历史 modelscope 布局 /whisper-{size}/model.bin 也兼容识别。) + """ + model_dir = Path(get_model_dir(subdir)) + # HF cache 布局 + hf_repo_dir = model_dir / f"models--Systran--faster-whisper-{model_size}" / "snapshots" + if hf_repo_dir.exists(): + for snapshot in hf_repo_dir.iterdir(): + if (snapshot / "model.bin").exists(): + return True + # 历史 modelscope 布局(向后兼容老用户) + legacy = model_dir / f"whisper-{model_size}" / "model.bin" + return legacy.exists() + + +def _check_mlx_whisper_model_exists(model_size: str) -> bool: + """检查 mlx-whisper 模型是否已下载完整到本地。 + + 与 fast-whisper 的目录布局不同:mlx 模型按 HuggingFace repo_id + (如 mlx-community/whisper-tiny-mlx)落盘,且没有 model.bin, + 用 config.json 作为「下载完成」的判据,和 mlx_whisper_transcriber.py 保持一致。 + """ + try: + from app.transcriber.mlx_whisper_transcriber import MLX_MODEL_MAP + except Exception: + return False + repo_id = MLX_MODEL_MAP.get(model_size) + if not repo_id: + return False + model_dir = get_model_dir("mlx-whisper") + model_path = os.path.join(model_dir, repo_id) + return (Path(model_path) / "config.json").exists() + + +# ---- FunASR 模型预下载 ---- +# 常用 FunASR 模型 → 实际需要的 modelscope 仓库(主模型 + 流水线依赖的 vad/punc)。 +# 预下载用 modelscope.snapshot_download 落到 funasr AutoModel 同一份缓存, +# 这样首个任务不再边跑边下(曾因下载中断产生损坏的 punc 模型)。 +_FUNASR_VAD_REPO = "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch" +_FUNASR_PUNC_REPO = "iic/punc_ct-transformer_cn-en-common-vocab471067-large" +FUNASR_MODEL_REPOS: dict = { + "paraformer-zh": [ + "iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + _FUNASR_VAD_REPO, + _FUNASR_PUNC_REPO, + ], + "SenseVoiceSmall": ["iic/SenseVoiceSmall", _FUNASR_VAD_REPO], + "paraformer-zh-streaming": [ + "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", + _FUNASR_VAD_REPO, + _FUNASR_PUNC_REPO, + ], +} + + +def _modelscope_cache_root() -> Path: + """modelscope 默认缓存根(funasr AutoModel 下载落点相同)。""" + return Path(os.path.expanduser(os.getenv("MODELSCOPE_CACHE", "~/.cache/modelscope"))) / "hub" / "models" + + +def _check_funasr_model_exists(name: str) -> bool: + """FunASR 模型(含其 vad/punc 依赖)是否已全部落盘。以 model.pt 存在为判据。""" + repos = FUNASR_MODEL_REPOS.get(name) + if not repos: + return False # 未知/自定义模型不做预下载判定,首跑时按需下载 + return all((_modelscope_cache_root() / r / "model.pt").exists() for r in repos) + + +def _do_download_funasr(name: str): + """后台预下载 FunASR 模型及其依赖(modelscope 自带校验,可断点续传/修复)。""" + key = f"funasr-{name}" + try: + _downloading[key] = "downloading" + from modelscope.hub.snapshot_download import snapshot_download as ms_download + + for repo in FUNASR_MODEL_REPOS.get(name, []): + if (_modelscope_cache_root() / repo / "model.pt").exists(): + continue + logger.info(f"下载 FunASR 模型: {repo}") + ms_download(repo) + logger.info(f"FunASR 模型下载完成: {name}") + _downloading[key] = "done" + except Exception as e: + logger.error(f"FunASR 模型下载失败: {name}, {e}") + _downloading[key] = "failed" + + +@router.get("/transcriber_models_status") +def get_transcriber_models_status(): + """返回所有 whisper 模型的下载状态。""" + statuses = [] + for size in WHISPER_MODEL_SIZES: + downloaded = _check_whisper_model_exists(size, "whisper") + download_status = _downloading.get(size) + statuses.append({ + "model_size": size, + "downloaded": downloaded, + "downloading": download_status == "downloading", + }) + + # 也检查 mlx-whisper(仅 macOS) + # 注意:import mlx_whisper 会 dlopen MLX 原生库,若打包缺少 libjaccl.dylib 等会抛 ImportError。 + # 必须 try/except 兜住——否则 mlx 不可用时会把上面已算好的 fast-whisper 状态一起 500 掉, + # 导致前端「模型管理」整张卡(含 fast-whisper 下载按钮)都不渲染。 + mlx_available = platform.system() == "Darwin" + mlx_statuses = [] + if mlx_available: + try: + from app.transcriber.mlx_whisper_transcriber import MLX_MODEL_MAP + for size in WHISPER_MODEL_SIZES: + mlx_key = f"mlx-{size}" + repo_id = MLX_MODEL_MAP.get(size) + # 用 config.json 判定,和 _check_mlx_whisper_model_exists / 加载逻辑保持一致 + downloaded = _check_mlx_whisper_model_exists(size) + mlx_statuses.append({ + "model_size": size, + "downloaded": downloaded, + "downloading": _downloading.get(mlx_key) == "downloading", + "available": repo_id is not None, + }) + except Exception as e: + logger.warning(f"mlx-whisper 不可用(原生库加载失败等),降级跳过其模型状态: {e}") + mlx_available = False + mlx_statuses = [] + + # FunASR 模型(预下载状态;不依赖 funasr 包,下载走 modelscope) + funasr_statuses = [ + { + "model_size": name, + "downloaded": _check_funasr_model_exists(name), + "downloading": _downloading.get(f"funasr-{name}") == "downloading", + } + for name in FUNASR_MODEL_REPOS + ] + + return R.success(data={ + "whisper": statuses, + "mlx_whisper": mlx_statuses, + "mlx_available": mlx_available, + "funasr": funasr_statuses, + }) + + +class ModelDownloadRequest(BaseModel): + model_size: str + transcriber_type: str = "fast-whisper" # "fast-whisper" 或 "mlx-whisper" + + +def _do_download_whisper(model_size: str): + """后台下载 faster-whisper 模型。 + + 直接走 huggingface_hub.snapshot_download,把模型放到 HF cache 布局里—— + 这样 faster-whisper 加载时(WhisperModel(model_size_or_path=size_name, + download_root=model_dir))能直接命中缓存,跟加载路径完全对齐。 + """ + from huggingface_hub import snapshot_download + + try: + _downloading[model_size] = "downloading" + model_dir = get_model_dir("whisper") + + # 已经下好就不重复下 + if _check_whisper_model_exists(model_size, "whisper"): + _downloading[model_size] = "done" + return + repo_id = f"Systran/faster-whisper-{model_size}" + logger.info(f"开始下载 whisper 模型: {repo_id}") + # 跟 faster-whisper utils.py 用同样的 allow_patterns,避免多下无关文件; + # 不传 local_dir 让它走 HF 默认 cache 布局(与加载逻辑对齐) + snapshot_download( + repo_id, + cache_dir=model_dir, + allow_patterns=[ + "config.json", + "preprocessor_config.json", + "model.bin", + "tokenizer.json", + "vocabulary.*", + ], + ) + logger.info(f"whisper 模型下载完成: {model_size}") + _downloading[model_size] = "done" + except Exception as e: + logger.error(f"whisper 模型下载失败: {model_size}, {e}") + _downloading[model_size] = "failed" + + +def _do_download_mlx_whisper(model_size: str): + """后台下载 mlx-whisper 模型。""" + key = f"mlx-{model_size}" + try: + _downloading[key] = "downloading" + from huggingface_hub import snapshot_download as hf_download + from app.transcriber.mlx_whisper_transcriber import resolve_mlx_repo_id + + try: + repo_id = resolve_mlx_repo_id(model_size) + except ValueError as e: + logger.error(str(e)) + _downloading[key] = "failed" + return + + model_dir = get_model_dir("mlx-whisper") + model_path = os.path.join(model_dir, repo_id) + # 用 config.json 判定而非目录存在:半成品目录不能算「已下载」 + if (Path(model_path) / "config.json").exists(): + _downloading[key] = "done" + return + logger.info(f"开始下载 mlx-whisper 模型: {model_size} ← {repo_id}") + hf_download(repo_id, local_dir=model_path, local_dir_use_symlinks=False) + logger.info(f"mlx-whisper 模型下载完成: {model_size}") + _downloading[key] = "done" + except Exception as e: + logger.error(f"mlx-whisper 模型下载失败: {model_size}, {e}") + _downloading[key] = "failed" + + +class ModelDeleteRequest(BaseModel): + model_size: str + transcriber_type: str = "fast-whisper" # "fast-whisper" / "mlx-whisper" / "funasr" + + +@router.post("/transcriber_delete") +def delete_transcriber_model(data: ModelDeleteRequest): + """卸载(删除)已下载到本地的转写模型,释放磁盘空间;可随时重新下载。""" + import shutil + + size = data.model_size + ttype = data.transcriber_type + + # 下载中的模型不允许删,避免半删半下产生损坏缓存 + dl_key = {"mlx-whisper": f"mlx-{size}", "funasr": f"funasr-{size}"}.get(ttype, size) + if _downloading.get(dl_key) == "downloading": + return R.error(msg="该模型正在下载中,请等待下载完成后再卸载") + + targets: list = [] + if ttype == "fast-whisper": + if size not in WHISPER_MODEL_SIZES: + return R.error(msg=f"未知模型: {size}") + model_dir = Path(get_model_dir("whisper")) + targets = [ + model_dir / f"models--Systran--faster-whisper-{size}", + model_dir / f"whisper-{size}", # 历史 modelscope 布局 + ] + elif ttype == "mlx-whisper": + try: + from app.transcriber.mlx_whisper_transcriber import resolve_mlx_repo_id + repo = resolve_mlx_repo_id(size) + except Exception as e: + return R.error(msg=f"未知模型: {size} ({e})") + targets = [Path(get_model_dir("mlx-whisper")) / repo] + elif ttype == "funasr": + repos = FUNASR_MODEL_REPOS.get(size) + if not repos: + return R.error(msg=f"未知模型: {size}") + # 共享依赖保护:vad/punc 被多个 FunASR 模型共用, + # 只删「其他已下载模型」不再需要的仓库 + keep = set() + for other, other_repos in FUNASR_MODEL_REPOS.items(): + if other != size and _check_funasr_model_exists(other): + keep.update(other_repos) + targets = [_modelscope_cache_root() / r for r in repos if r not in keep] + else: + return R.error(msg=f"未知转写器类型: {ttype}") + + removed = 0 + for t in targets: + if t.exists(): + shutil.rmtree(t, ignore_errors=True) + removed += 1 + logger.info(f"已卸载模型目录: {t}") + _downloading.pop(dl_key, None) # 清掉历史下载状态,避免显示残留 + + if removed == 0: + return R.success(msg="模型不存在或已卸载") + return R.success(msg="模型已卸载") + + +@router.post("/transcriber_download") +def download_transcriber_model(data: ModelDownloadRequest, background_tasks: BackgroundTasks): + """触发后台下载指定的转写模型(whisper / mlx-whisper / funasr)。""" + # FunASR:model_size 字段承载 FunASR 模型名(复用既有请求结构) + if data.transcriber_type == "funasr": + if data.model_size not in FUNASR_MODEL_REPOS: + return R.error(msg=f"不支持预下载的 FunASR 模型: {data.model_size}") + key = f"funasr-{data.model_size}" + if _downloading.get(key) == "downloading": + return R.success(msg="模型正在下载中") + background_tasks.add_task(_do_download_funasr, data.model_size) + return R.success(msg="模型下载已开始") + + if data.model_size not in WHISPER_MODEL_SIZES: + return R.error(msg=f"不支持的模型大小: {data.model_size}") + + if data.transcriber_type == "mlx-whisper": + if platform.system() != "Darwin": + return R.error(msg="MLX Whisper 仅支持 macOS") + key = f"mlx-{data.model_size}" + if _downloading.get(key) == "downloading": + return R.success(msg="模型正在下载中") + background_tasks.add_task(_do_download_mlx_whisper, data.model_size) + else: + if _downloading.get(data.model_size) == "downloading": + return R.success(msg="模型正在下载中") + background_tasks.add_task(_do_download_whisper, data.model_size) + + return R.success(msg="模型下载已开始") + + +@router.get("/sys_health") +async def sys_health(): + """结构化健康状态——任何子项异常都不应让整个 endpoint 5xx。 + + 每个字段:'ok' | 'missing' | 'error'。 + 前端 useCheckBackend 用 /sys_check 做存活判定(不依赖外部依赖), + /sys_health 用来在设置页区分「后端没起」vs「后端起了但 ffmpeg 缺」vs「DB 写不进去」等更细的状态。 + """ + ffmpeg_status = "ok" + try: + ensure_ffmpeg_or_raise() + except Exception: + ffmpeg_status = "missing" + + db_status = "ok" + try: + from app.db.engine import engine + from sqlalchemy import text + with engine.connect() as conn: + conn.execute(text("SELECT 1")) + except Exception: + db_status = "error" + + # 当前转写器配置 + 模型是否已下载(用 model.bin 落盘判定,与 transcriber 加载逻辑一致) + whisper_info: dict = {"size": None, "type": None, "downloaded": False, "checked": False} + try: + cfg = transcriber_config_manager.get_config() + size = cfg["whisper_model_size"] + ttype = cfg["transcriber_type"] + whisper_info["size"] = size + whisper_info["type"] = ttype + # 只有本地引擎才有「下载」概念;groq / bcut / kuaishou 在线引擎跳过 + if ttype == "fast-whisper": + whisper_info["downloaded"] = _check_whisper_model_exists(size, "whisper") + whisper_info["checked"] = True + elif ttype == "mlx-whisper": + whisper_info["downloaded"] = _check_mlx_whisper_model_exists(size) + whisper_info["checked"] = True + except Exception: + pass + + return R.success(data={ + "backend": "ok", + "ffmpeg": ffmpeg_status, + "db": db_status, + "whisper_model": whisper_info, + }) + + +@router.get("/sys_check") +async def sys_check(): + """轻量存活判定:后端进程能响应这个 endpoint 就算「起来了」,不查外部依赖。 + + 给桌面端 useCheckBackend / Tauri ready-probe 用。 + """ + return R.success() + + +@router.get("/deploy_status") +async def deploy_status(): + """返回部署监控所需的所有状态信息。 + + 所有子项都用 try 包起来——监控页本身不应该被任何一个子项打死。 + 特别是 torch:它只在 fast-whisper 路径用得到,用 Groq / 必剪 / 快手在线 + 引擎的轻量部署完全可以不装,那种情况这个 endpoint 不应该 500。 + """ + import os + + # CUDA 状态 + try: + import torch + cuda_available = torch.cuda.is_available() + cuda_info = { + "available": cuda_available, + "torch_installed": True, + "version": torch.version.cuda if cuda_available else None, + "gpu_name": torch.cuda.get_device_name(0) if cuda_available else None, + } + except Exception: + cuda_info = { + "available": False, + "torch_installed": False, + "version": None, + "gpu_name": None, + } + + # Whisper 模型 / 转写器配置 + 本地下载状态 + try: + transcriber_cfg = transcriber_config_manager.get_config() + size = transcriber_cfg["whisper_model_size"] + ttype = transcriber_cfg["transcriber_type"] + if ttype == "fast-whisper": + downloaded = _check_whisper_model_exists(size, "whisper") + elif ttype == "mlx-whisper": + downloaded = _check_mlx_whisper_model_exists(size) + else: + downloaded = False # 在线引擎无下载概念 + whisper_info = { + "model_size": size, + "transcriber_type": ttype, + "downloaded": downloaded, + } + except Exception: + whisper_info = {"model_size": None, "transcriber_type": None, "downloaded": False} + + # FFmpeg 状态 + try: + ensure_ffmpeg_or_raise() + ffmpeg_ok = True + except Exception: + ffmpeg_ok = False + + return R.success(data={ + "backend": {"status": "running", "port": int(os.getenv("BACKEND_PORT", 8483))}, + "cuda": cuda_info, + "whisper": whisper_info, + "ffmpeg": {"available": ffmpeg_ok}, + }) diff --git a/backend/app/routers/feishu.py b/backend/app/routers/feishu.py new file mode 100644 index 0000000000000000000000000000000000000000..5d63e575469556639e9a4ffc24a74f18458173cb --- /dev/null +++ b/backend/app/routers/feishu.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import logging +import os +from datetime import datetime +from typing import Optional + +from fastapi import APIRouter +from pydantic import BaseModel + +from app.services import feishu_pusher +from app.services.feishu_config_manager import FeishuConfigManager +from app.services.feishu_service import FeishuError +from app.utils.response import ResponseWrapper as R + +logger = logging.getLogger(__name__) + +router = APIRouter() + +# 后端对外可访问地址:用于把笔记里的 /static、/uploads 相对图片补成绝对链接, +# 让飞书导入时有机会抓到图(与 app/services/note.py 的 BACKEND_BASE_URL 同源)。 +_API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost") +_BACKEND_PORT = os.getenv("BACKEND_PORT", "8483") +BACKEND_ORIGIN = f"{_API_BASE_URL}:{_BACKEND_PORT}" + + +class FeishuConfigRequest(BaseModel): + app_id: Optional[str] = None + app_secret: Optional[str] = None + folder_token: Optional[str] = None + base_url: Optional[str] = None + auto_push: Optional[bool] = None + enabled: Optional[bool] = None + push_backend: Optional[str] = None # "auto" | "rest" | "cli" + cli_path: Optional[str] = None + + +class FeishuPushRequest(BaseModel): + task_id: str + version_id: Optional[str] = None + + +def push_task_to_feishu(task_id: str, version_id: Optional[str] = None) -> dict: + """读取已生成笔记 → 导入飞书文档 → 把结果写回笔记 JSON,返回 feishu 信息。 + + 被「手动推送」接口与「生成后自动推送」共用。任何失败都抛 FeishuError, + 由调用方决定是返回错误还是仅记日志(自动推送场景不应中断主流程)。 + + 直接读写原始 JSON(不走 _read/_write_note_json 的版本归一化),避免把 + markdown 字符串就地改写成版本数组、抹掉单版本笔记的 model/style 元信息。 + """ + import json + + # 延迟导入避免与 note 路由的循环依赖;note 路由不在模块级 import 本模块 + from app.routers.note import NOTE_OUTPUT_DIR, _pick_markdown_version + + path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(path): + raise FeishuError(f"笔记不存在或尚未生成完成:{task_id}") + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + + # _pick_markdown_version 兼容旧 str 与多版本 list 两种格式,仅读取不改写 + content = _pick_markdown_version(data.get("markdown"), version_id) + if not (content or "").strip(): + raise FeishuError("笔记内容为空,无法推送到飞书") + + audio_meta = data.get("audio_meta") or {} + raw_info = audio_meta.get("raw_info") or {} + title = audio_meta.get("title") or raw_info.get("title") or f"VideoMemo 笔记 {task_id[:8]}" + + result = feishu_pusher.push_markdown( + title=title, + markdown=content, + image_base_url=BACKEND_ORIGIN, + ) + feishu_info = { + "url": result.get("url", ""), + "token": result.get("token", ""), + "type": result.get("type", "docx"), + "title": result.get("title", title), + "pushed_at": datetime.now().isoformat(timespec="seconds"), + } + data["feishu"] = feishu_info + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + return feishu_info + + +def auto_push_if_enabled(task_id: str) -> None: + """笔记生成成功后调用:开启「自动推送」时把笔记推到飞书。失败只记日志,不影响主流程。""" + try: + if not FeishuConfigManager().is_auto_push_enabled(): + return + info = push_task_to_feishu(task_id) + logger.info(f"飞书自动推送成功 (task_id={task_id}) -> {info.get('url')}") + except Exception as e: + logger.warning(f"飞书自动推送失败 (task_id={task_id}):{e}") + + +# ─── 配置 ──────────────────────────────────────────────────────────────────── +@router.get("/feishu_config") +def get_feishu_config(): + return R.success(FeishuConfigManager().get_public_config()) + + +@router.post("/feishu_config") +def update_feishu_config(data: FeishuConfigRequest): + cfg = FeishuConfigManager().update_config( + enabled=data.enabled, + auto_push=data.auto_push, + app_id=data.app_id, + app_secret=data.app_secret, + folder_token=data.folder_token, + base_url=data.base_url, + push_backend=data.push_backend, + cli_path=data.cli_path, + ) + return R.success(cfg, msg="飞书配置已保存") + + +@router.post("/feishu_test") +def test_feishu_connection(): + try: + result = feishu_pusher.test_connection() + return R.success(result, msg=result.get("message", "连接成功")) + except FeishuError as e: + return R.error(msg=e.message, code=400) + except Exception as e: + logger.error(f"飞书连接测试异常: {e}", exc_info=True) + return R.error(msg=f"连接失败:{e}", code=400) + + +# ─── 推送笔记 ───────────────────────────────────────────────────────────────── +@router.post("/feishu_push") +def push_note_to_feishu(data: FeishuPushRequest): + try: + info = push_task_to_feishu(data.task_id, data.version_id) + return R.success(info, msg="已推送到飞书文档") + except FeishuError as e: + return R.error(msg=e.message, code=400) + except Exception as e: + logger.error(f"飞书推送失败 (task_id={data.task_id}): {e}", exc_info=True) + return R.error(msg=f"推送失败:{e}", code=500) diff --git a/backend/app/routers/flashcard.py b/backend/app/routers/flashcard.py new file mode 100644 index 0000000000000000000000000000000000000000..159097c495780c6d9c5639174d87269539b80203 --- /dev/null +++ b/backend/app/routers/flashcard.py @@ -0,0 +1,104 @@ +import json +import re + +from fastapi import APIRouter +from pydantic import BaseModel + +from app.gpt.gpt_factory import GPTFactory +from app.gpt.utils import strip_think_blocks +from app.models.model_config import ModelConfig +from app.services.provider import ProviderService +from app.utils.logger import get_logger +from app.utils.response import ResponseWrapper as R + +logger = get_logger(__name__) + +router = APIRouter() + + +class FlashcardRequest(BaseModel): + content: str + provider_id: str + model_name: str + count: int = 10 + + +SYSTEM_PROMPT = """你是一个学习卡片生成助手。请根据用户提供的笔记内容,提炼关键知识点,生成问答式记忆闪卡。 + +要求: +- 每张卡片包含 front(问题/正面)和 back(答案/背面) +- 问题应聚焦核心概念、定义、关键结论,便于主动回忆 +- 答案简洁准确,控制在 1~3 句话 +- 最多生成 {count} 张,不要硬凑,宁缺毋滥 +- 严格只输出 JSON 数组,不要任何额外说明或 markdown 代码块,格式: +[{{"front": "问题", "back": "答案"}}]""" + + +def _parse_cards(text: str) -> list[dict]: + """从 LLM 输出中解析出卡片数组,容忍代码块包裹。""" + cleaned = text.strip() + # 去掉 ```json ... ``` 包裹 + cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", cleaned, flags=re.IGNORECASE) + try: + data = json.loads(cleaned) + except json.JSONDecodeError: + # 退化:抓取第一个 JSON 数组 + match = re.search(r"\[.*\]", cleaned, flags=re.DOTALL) + if not match: + return [] + try: + data = json.loads(match.group(0)) + except json.JSONDecodeError: + return [] + + cards = [] + for item in data if isinstance(data, list) else []: + front = (item or {}).get("front") + back = (item or {}).get("back") + if front and back: + cards.append({"front": str(front), "back": str(back)}) + return cards + + +@router.post("/flashcards/generate") +def generate_flashcards(data: FlashcardRequest): + """根据笔记内容用 LLM 生成问答闪卡。""" + content = data.content.strip() + if not content: + return R.error(msg="笔记内容为空,无法生成闪卡") + + provider = ProviderService.get_provider_by_id(data.provider_id) + if not provider: + return R.error(msg=f"未找到模型供应商: {data.provider_id}") + + config = ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + model_name=data.model_name, + provider=provider["type"], + name=provider["name"], + ) + gpt = GPTFactory.from_config(config) + + # 控制输入长度,避免超长 token + max_chars = 12000 + snippet = content[:max_chars] + + try: + response = gpt.client.chat.completions.create( + model=gpt.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT.format(count=data.count)}, + {"role": "user", "content": snippet}, + ], + temperature=0.4, + ) + raw = strip_think_blocks(response.choices[0].message.content) + cards = _parse_cards(raw) + if not cards: + logger.warning(f"闪卡解析为空,原始输出: {raw[:200]}") + return R.error(msg="未能生成有效闪卡,请重试") + return R.success(data={"cards": cards}) + except Exception as e: + logger.error(f"生成闪卡失败: {e}", exc_info=True) + return R.error(msg=f"生成闪卡失败: {str(e)}") diff --git a/backend/app/routers/hot_videos.py b/backend/app/routers/hot_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..5532f657c2ffe5fee0e9bd0d64c49132f29cc071 --- /dev/null +++ b/backend/app/routers/hot_videos.py @@ -0,0 +1,20 @@ +from fastapi import APIRouter, Query + +from app.services.hot_videos import fetch_hot_video_payload +from app.utils.response import ResponseWrapper as R + +router = APIRouter() + + +@router.get("/hot_videos") +def get_hot_videos( + platform: str = Query("all"), + limit: int = Query(12, ge=1, le=30), + force: bool = Query(False), +): + try: + return R.success(fetch_hot_video_payload(platform=platform, limit=limit, force=force)) + except ValueError as exc: + return R.error(msg=str(exc), code=400) + except Exception as exc: + return R.error(msg=f"热点视频获取失败: {exc}") diff --git a/backend/app/routers/model.py b/backend/app/routers/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a0630a8ba2741eac1eeaad9fc15e5d939049e9f4 --- /dev/null +++ b/backend/app/routers/model.py @@ -0,0 +1,51 @@ +from fastapi import APIRouter +from pydantic import BaseModel + +from app.services.model import ModelService +from app.utils.response import ResponseWrapper as R +router = APIRouter() +modelService = ModelService() +class CreateModelRequest(BaseModel): + provider_id: str + model_name: str + +# 返回体:模型信息 +class ModelItem(BaseModel): + id: int + model_name: str +@router.get("/model_list") +def model_list(): + try: + return R.success(modelService.get_all_models(True),msg="获取模型列表成功") + except Exception as e: + return R.error(e) +@router.get("/models/delete/{model_id}") +def delete_model(model_id: int): + try: + success = modelService.delete_model_by_id(model_id) + if success: + return R.success(msg="模型删除成功") + else: + return R.error("模型不存在或删除失败") + except Exception as e: + return R.error(f"删除模型失败: {e}") +@router.get("/model_list/{provider_id}") +def model_list(provider_id): + + return R.success(modelService.get_all_models_by_id(provider_id)) + + +@router.post("/models") +def create_model(data: CreateModelRequest): + success = ModelService.add_new_model(data.provider_id, data.model_name) + if not success: + return R.error("模型添加失败") + return R.success(msg="模型添加成功") + +@router.get("/model_enable/{provider_id}") +def get_enabled_models_by_provider(provider_id: str): + try: + models = modelService.get_enabled_models_by_provider(provider_id) + return R.success(models, msg="获取启用模型成功") + except Exception as e: + return R.error(f"获取启用模型失败: {e}") \ No newline at end of file diff --git a/backend/app/routers/note.py b/backend/app/routers/note.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6409ad21f7bcdd3b2f8312fe4ae44cce88cc51 --- /dev/null +++ b/backend/app/routers/note.py @@ -0,0 +1,634 @@ +# app/routers/note.py +import json +import os +import uuid +from pathlib import Path +from typing import Optional +from urllib.parse import urlparse + +from fastapi import APIRouter, HTTPException, BackgroundTasks, UploadFile, File +from pydantic import BaseModel, validator, field_validator +from dataclasses import asdict + +from app.db.video_task_dao import get_task_by_video +from app.enmus.exception import NoteErrorEnum +from app.enmus.note_enums import DownloadQuality +from app.exceptions.note import NoteError +from app.services.note import NoteGenerator, logger +from app.services import task_control +from app.services.task_serial_executor import task_serial_executor +from app.utils.response import ResponseWrapper as R +from app.utils.url_parser import extract_video_id +from app.validators.video_url_validator import is_supported_video_url +from fastapi import APIRouter, Request, HTTPException +from fastapi.responses import StreamingResponse, FileResponse, Response +import httpx +from app.enmus.task_status_enums import TaskStatus + +# from app.services.downloader import download_raw_audio +# from app.services.whisperer import transcribe_audio + +router = APIRouter() + + +class RecordRequest(BaseModel): + video_id: str + platform: str + + +class VideoRequest(BaseModel): + video_url: str + platform: str + quality: DownloadQuality + screenshot: Optional[bool] = False + link: Optional[bool] = False + model_name: str + provider_id: str + task_id: Optional[str] = None + format: Optional[list] = [] + style: str = None + extras: Optional[str]=None + video_understanding: Optional[bool] = False + video_interval: Optional[int] = 0 + grid_size: Optional[list] = [] + # 客户端(如浏览器插件)已经在用户浏览器里抓到字幕,直接传给后端复用, + # 跳过 download_subtitles 和音频转写。形如: + # {"language": "zh", "full_text": "...", "segments": [{"start","end","text"}, ...]} + prefetched_transcript: Optional[dict] = None + + @field_validator("video_url") + def validate_supported_url(cls, v): + url = str(v) + parsed = urlparse(url) + if parsed.scheme in ("http", "https"): + # 是网络链接,继续用原有平台校验 + if not is_supported_video_url(url): + raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code, + message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message) + + return v + + +NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results") +UPLOAD_DIR = "uploads" + + +def save_note_to_file(task_id: str, note): + os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True) + with open(os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json"), "w", encoding="utf-8") as f: + json.dump(asdict(note), f, ensure_ascii=False, indent=2) + + +def _persist_prefetched_transcript(task_id: str, transcript: dict) -> None: + """把客户端预取的字幕写到 NoteGenerator 期望的转写缓存文件里。 + + NoteGenerator.generate 会优先读 _transcript.json,命中即跳过 download_subtitles + 与音频转写流程。要求字段:language(可空)/full_text/segments[{start,end,text}] + """ + segments = transcript.get("segments") or [] + cleaned_segments = [] + for s in segments: + text = (s.get("text") or "").strip() + if not text: + continue + cleaned_segments.append({ + "start": float(s.get("start", 0)), + "end": float(s.get("end", 0)), + "text": text, + }) + if not cleaned_segments: + raise ValueError("prefetched_transcript 没有可用的 segments") + + full_text = transcript.get("full_text") or " ".join(s["text"] for s in cleaned_segments) + payload = { + "language": transcript.get("language") or "zh", + "full_text": full_text, + "segments": cleaned_segments, + } + + os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True) + target = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}_transcript.json") + with open(target, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + logger.info(f"已写入客户端预取字幕缓存: {target} ({len(cleaned_segments)} 段)") + + +def run_note_task(task_id: str, video_url: str, platform: str, quality: DownloadQuality, + link: bool = False, screenshot: bool = False, model_name: str = None, provider_id: str = None, + _format: list = None, style: str = None, extras: str = None, video_understanding: bool = False, + video_interval=0, grid_size=[] + ): + + if not model_name or not provider_id: + raise HTTPException(status_code=400, detail="请选择模型和提供者") + + def _execute_note_task(): + return NoteGenerator().generate( + video_url=video_url, + platform=platform, + quality=quality, + task_id=task_id, + model_name=model_name, + provider_id=provider_id, + link=link, + _format=_format, + style=style, + extras=extras, + screenshot=screenshot, + video_understanding=video_understanding, + video_interval=video_interval, + grid_size=grid_size, + ) + + logger.info(f"任务进入执行队列 (task_id={task_id})") + note = task_serial_executor.run(_execute_note_task) + logger.info(f"Note generated: {task_id}") + if not note or not note.markdown: + logger.warning(f"任务 {task_id} 执行失败,跳过保存") + return + save_note_to_file(task_id, note) + + # 自动建立向量索引(用于 AI 问答),失败不影响笔记生成 + try: + from app.services.vector_store import VectorStoreManager + VectorStoreManager().index_task(task_id) + except Exception as e: + logger.warning(f"向量索引失败(不影响笔记): {e}") + + # 生成后自动推送到飞书文档(仅在「设置 → 飞书推送」开启了自动推送时触发), + # 内部已吞掉自身异常,这里再兜一层防止 import 失败影响主流程 + try: + from app.routers.feishu import auto_push_if_enabled + auto_push_if_enabled(task_id) + except Exception as e: + logger.warning(f"飞书自动推送调度失败(不影响笔记): {e}") + + +@router.post('/delete_task') +def delete_task(data: RecordRequest): + try: + # TODO: 待持久化完成 + # NoteGenerator().delete_note(video_id=data.video_id, platform=data.platform) + return R.success(msg='删除成功') + except Exception as e: + return R.error(msg=e) + + +def _safe_filename(name: str, fallback: str = "note") -> str: + """剔除文件名里 OS 不允许的字符,长度截到 80。""" + import re as _re + cleaned = _re.sub(r'[\\/:*?"<>|\r\n\t]', "", (name or "").strip()) + cleaned = cleaned.strip(". ") + return (cleaned[:80] or fallback) + + +def _normalize_versions(markdown_field, fallback_meta: Optional[dict] = None) -> list: + """把 markdown 字段统一成版本数组形式。 + - str:包成一个 source='generated' 的版本(旧笔记自动迁移) + - list:直接返回(已经是多版本) + - 其它:空数组 + """ + if isinstance(markdown_field, list): + return markdown_field + if isinstance(markdown_field, str) and markdown_field.strip(): + meta = fallback_meta or {} + return [{ + "ver_id": "v1", + "content": markdown_field, + "style": meta.get("style", ""), + "model_name": meta.get("model_name", ""), + "source": "generated", + "created_at": meta.get("created_at", ""), + }] + return [] + + +def _pick_markdown_version(markdown_field, version_id: Optional[str]) -> str: + """从笔记的 markdown 字段里挑出对应版本的字符串内容。 + 向后兼容旧格式(markdown 是 str)与新多版本格式(list[VersionNote])。 + """ + versions = _normalize_versions(markdown_field) + if not versions: + return "" + if version_id: + for v in versions: + if v.get("ver_id") == version_id: + return v.get("content") or v.get("markdown") or "" + # 默认取最新一版(按 created_at 排序,缺省取末尾) + try: + latest = sorted( + versions, + key=lambda v: v.get("created_at") or "", + reverse=True, + )[0] + except Exception: + latest = versions[-1] + return latest.get("content") or latest.get("markdown") or "" + + +def _read_note_json(task_id: str) -> dict: + """读笔记 JSON 文件,把 markdown 字段就地归一化成版本数组。文件不存在抛 HTTPException。""" + path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(path): + raise HTTPException(status_code=404, detail="笔记不存在") + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + data["markdown"] = _normalize_versions(data.get("markdown")) + return data + + +def _write_note_json(task_id: str, data: dict) -> None: + """写回笔记 JSON 文件。markdown 一定是 list 形式。""" + path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + os.makedirs(NOTE_OUTPUT_DIR, exist_ok=True) + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + +def _append_version( + versions: list, + content: str, + source: str, + model_name: str = "", + style: str = "", +) -> dict: + """构造一个新版本并 append 到 versions,返回新版本的 dict(含 ver_id)。""" + from datetime import datetime + new_ver = { + "ver_id": uuid.uuid4().hex[:12], + "content": content, + "style": style or "", + "model_name": model_name or "", + "source": source, # 'generated' | 'manual' | 'repolish' + "created_at": datetime.now().isoformat(timespec="seconds"), + } + versions.append(new_ver) + return new_ver + + +def _trigger_reindex(task_id: str, background_tasks: BackgroundTasks) -> None: + """编辑/润色/删版本后异步重建向量索引,避免知识检索查到旧内容。失败不影响主流程。""" + def _do(): + try: + from app.services.vector_store import VectorStoreManager + VectorStoreManager().index_task(task_id) + except Exception as e: + logger.warning(f"重建向量索引失败 task_id={task_id}: {e}") + background_tasks.add_task(_do) + + +@router.get("/export_note/{task_id}") +def export_note(task_id: str, format: str = "markdown", version_id: Optional[str] = None): + """ + 导出指定笔记为多种格式。 + - format: markdown / pdf / html / word / docx(image / png 暂不支持) + - version_id: 多版本时指定某一版;不传取最新版(v1 单版兼容旧 str 格式) + """ + note_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(note_path): + raise HTTPException(status_code=404, detail="笔记不存在") + + with open(note_path, "r", encoding="utf-8") as f: + note = json.load(f) + + content = _pick_markdown_version(note.get("markdown"), version_id) + if not content.strip(): + raise HTTPException(status_code=404, detail="笔记内容为空") + + audio_meta = note.get("audio_meta", {}) or {} + raw_title = audio_meta.get("title") or (audio_meta.get("raw_info") or {}).get("title") or task_id + title = _safe_filename(raw_title, fallback=task_id) + + fmt = format.lower().strip() + if fmt == "markdown" or fmt == "md": + # 直接返回 Markdown 文本,让浏览器另存。中文文件名按 RFC 5987 编码, + # 否则 Chrome / Safari 会丢掉非 ASCII 的 filename。 + from urllib.parse import quote + ascii_fallback = f"{task_id}.md" + cd = ( + f"attachment; filename=\"{ascii_fallback}\"; " + f"filename*=UTF-8''{quote(title + '.md', safe='')}" + ) + return Response( + content=content, + media_type="text/markdown; charset=utf-8", + headers={"Content-Disposition": cd}, + ) + + # 其它格式落到 ExportUtils(pdf / html / word / docx) + try: + from app.utils.export import ExportUtils + save_path = ExportUtils().export(fmt, title=title, content=content) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(f"导出失败 task_id={task_id} format={fmt}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"导出失败: {e}") + + if not os.path.exists(save_path): + raise HTTPException(status_code=500, detail="导出文件未生成") + + media_map = { + "pdf": "application/pdf", + "html": "text/html; charset=utf-8", + "word": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + } + return FileResponse( + path=save_path, + media_type=media_map.get(fmt, "application/octet-stream"), + filename=os.path.basename(save_path), + ) + + +# ── 笔记多版本:手动编辑 / AI 重新润色 / 删版本 ───────────────── + + +class ManualEditRequest(BaseModel): + content: str + style: Optional[str] = None # 沿用上一版风格,不传也行 + + +class RepolishRequest(BaseModel): + style: Optional[str] = None + extras: Optional[str] = None + provider_id: str + model_name: str + + +def _versions_response_payload(task_id: str, data: dict, new_version: Optional[dict] = None) -> dict: + """统一的响应负载:把整篇笔记 + 新版本 id 返回给前端,前端可直接刷整页状态。""" + return { + "task_id": task_id, + "markdown": data["markdown"], # 已被归一化成 list + "current_ver_id": new_version["ver_id"] if new_version else None, + } + + +@router.patch("/note/{task_id}") +def update_note(task_id: str, body: ManualEditRequest, background_tasks: BackgroundTasks): + """手动编辑笔记 —— 把当前编辑的内容作为新版本追加到 markdown 数组。""" + if not body.content.strip(): + raise HTTPException(status_code=400, detail="笔记内容不能为空") + data = _read_note_json(task_id) + versions = data["markdown"] # 已 normalize + + # 沿用最新版本的 model_name + style 作默认值(保持元数据一致性) + last = versions[-1] if versions else {} + new_ver = _append_version( + versions, + content=body.content, + source="manual", + model_name=last.get("model_name", ""), + style=body.style or last.get("style", ""), + ) + data["markdown"] = versions + _write_note_json(task_id, data) + _trigger_reindex(task_id, background_tasks) + return R.success(data=_versions_response_payload(task_id, data, new_ver)) + + +@router.post("/note/{task_id}/repolish") +def repolish_note(task_id: str, body: RepolishRequest, background_tasks: BackgroundTasks): + """AI 重新润色 —— 用现有 markdown + transcript 调 LLM 生成新风格,作为新版本追加。""" + try: + new_content = NoteGenerator().repolish( + task_id=task_id, + style=body.style, + extras=body.extras, + provider_id=body.provider_id, + model_name=body.model_name, + ) + except NoteError as e: + raise HTTPException(status_code=400, detail=e.message) + except Exception as e: + logger.error(f"重新润色失败 task_id={task_id}: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"重新润色失败: {e}") + + if not new_content or not new_content.strip(): + raise HTTPException(status_code=500, detail="模型未返回内容") + + data = _read_note_json(task_id) + versions = data["markdown"] + new_ver = _append_version( + versions, + content=new_content, + source="repolish", + model_name=body.model_name, + style=body.style or "", + ) + data["markdown"] = versions + _write_note_json(task_id, data) + _trigger_reindex(task_id, background_tasks) + return R.success(data=_versions_response_payload(task_id, data, new_ver)) + + +@router.delete("/note/{task_id}/version/{ver_id}") +def delete_version(task_id: str, ver_id: str, background_tasks: BackgroundTasks): + """删除某个版本。至少保留一个版本,否则拒绝。""" + data = _read_note_json(task_id) + versions = data["markdown"] + if len(versions) <= 1: + raise HTTPException(status_code=400, detail="至少需要保留一个版本") + + new_versions = [v for v in versions if v.get("ver_id") != ver_id] + if len(new_versions) == len(versions): + raise HTTPException(status_code=404, detail=f"版本不存在:{ver_id}") + + data["markdown"] = new_versions + _write_note_json(task_id, data) + _trigger_reindex(task_id, background_tasks) + return R.success(data=_versions_response_payload(task_id, data, new_versions[-1])) + + +@router.post("/upload") +async def upload(file: UploadFile = File(...)): + os.makedirs(UPLOAD_DIR, exist_ok=True) + file_location = os.path.join(UPLOAD_DIR, file.filename) + + with open(file_location, "wb+") as f: + f.write(await file.read()) + + # 假设你静态目录挂载了 /uploads + return R.success({"url": f"/uploads/{file.filename}"}) + + +@router.post("/generate_note") +def generate_note(data: VideoRequest, background_tasks: BackgroundTasks): + try: + # 就绪门禁:本地转写引擎(fast-whisper / mlx-whisper)必须等模型下载完才能跑视频, + # 否则任务会卡在首次下载(慢 / OOM / 截断),用户只看到一个静默失败的任务。 + # 客户端已抓好字幕(prefetched_transcript)则不需要转写,跳过检查。 + if not data.prefetched_transcript: + from app.services.transcriber_config_manager import TranscriberConfigManager + readiness = TranscriberConfigManager().is_model_ready() + if not readiness["ready"]: + logger.warning(f"拒绝 generate_note:{readiness['reason']}") + return R.error( + msg=readiness["reason"], + code=300102, + data={ + "reason": "transcriber_model_not_ready", + "transcriber_type": readiness["transcriber_type"], + "model_size": readiness["model_size"], + "downloading": readiness["downloading"], + }, + ) + + video_id = extract_video_id(data.video_url, data.platform) + # if not video_id: + # raise HTTPException(status_code=400, detail="无法提取视频 ID") + # existing = get_task_by_video(video_id, data.platform) + # if existing: + # return R.error( + # msg='笔记已生成,请勿重复发起', + # + # ) + if data.task_id: + # 如果传了task_id,说明是重试! + task_id = data.task_id + logger.info(f"重试模式,复用已有 task_id={task_id}") + else: + # 正常新建任务 + task_id = str(uuid.uuid4()) + + # 统一先写入 PENDING,表示已进入队列等待串行执行 + NoteGenerator()._update_status(task_id, TaskStatus.PENDING) + + # 客户端已经抓好字幕的话,写到转写缓存文件,NoteGenerator 的 cache-hit 逻辑会直接用上 + if data.prefetched_transcript: + try: + _persist_prefetched_transcript(task_id, data.prefetched_transcript) + except Exception as e: + logger.warning(f"写入预取字幕失败 (task_id={task_id}): {e}") + + background_tasks.add_task(run_note_task, task_id, data.video_url, data.platform, data.quality, data.link, + data.screenshot, data.model_name, data.provider_id, data.format, data.style, + data.extras, data.video_understanding, data.video_interval, data.grid_size) + return R.success({"task_id": task_id}) + except Exception as e: + # 用业务错误格式返回(而不是 HTTPException 500): + # 前端拦截器读的是 msg 字段,500 的 detail 会被吞成笼统的「服务器错误,请稍后再试」, + # 用户看不到「转写引擎不可用,请安装/切换」这类可行动的原因。 + logger.error(f"generate_note 入口失败: {e}", exc_info=True) + return R.error(msg=str(e)) + + +@router.get("/task_status/{task_id}") +def get_task_status(task_id: str): + status_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.status.json") + result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + + # 优先读状态文件 + if os.path.exists(status_path): + with open(status_path, "r", encoding="utf-8") as f: + status_content = json.load(f) + + status = status_content.get("status") + message = status_content.get("message", "") + paused = bool(status_content.get("paused", False)) + cache = status_content.get("cache") + + if status == TaskStatus.SUCCESS.value: + # 成功状态的话,继续读取最终笔记内容 + if os.path.exists(result_path): + with open(result_path, "r", encoding="utf-8") as rf: + result_content = json.load(rf) + return R.success({ + "status": status, + "result": result_content, + "message": message, + "cache": cache, + "task_id": task_id + }) + else: + # 理论上不会出现,保险处理 + return R.success({ + "status": TaskStatus.PENDING.value, + "message": "任务完成,但结果文件未找到", + "cache": cache, + "task_id": task_id + }) + + if status == TaskStatus.FAILED.value: + return R.error(message or "任务失败", code=500) + + # 处理中状态 + return R.success({ + "status": status, + "message": message, + "paused": paused, + "cache": cache, + "task_id": task_id + }) + + # 没有状态文件,但有结果 + if os.path.exists(result_path): + with open(result_path, "r", encoding="utf-8") as f: + result_content = json.load(f) + return R.success({ + "status": TaskStatus.SUCCESS.value, + "result": result_content, + "task_id": task_id + }) + + # 什么都没有,默认PENDING + return R.success({ + "status": TaskStatus.PENDING.value, + "message": "任务排队中", + "task_id": task_id + }) + + +class TaskControlRequest(BaseModel): + task_id: str + action: str # 'pause' | 'resume' + + +@router.post("/task_control") +def task_control_endpoint(data: TaskControlRequest): + """暂停 / 继续任务。暂停仅在步骤之间生效(总结阶段之后不可暂停)。""" + if data.action == "pause": + task_control.pause(data.task_id) + return R.success(msg="已请求暂停") + if data.action == "resume": + task_control.resume(data.task_id) + return R.success(msg="已继续") + return R.error(msg="无效的操作") + + +# Referer 选择逻辑移到 cover_helper 统一维护(image_proxy 与笔记生成时的封面本地化共用) +from app.utils.cover_helper import pick_referer as _pick_referer + + +@router.get("/image_proxy") +async def image_proxy(request: Request, url: str): + headers = { + "User-Agent": request.headers.get( + "User-Agent", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + ), + } + referer = _pick_referer(url) + if referer: + headers["Referer"] = referer + + try: + async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client: + resp = await client.get(url, headers=headers) + + if resp.status_code != 200: + raise HTTPException(status_code=resp.status_code, detail="图片获取失败") + + content_type = resp.headers.get("Content-Type", "image/jpeg") + return StreamingResponse( + resp.aiter_bytes(), + media_type=content_type, + headers={ + "Cache-Control": "public, max-age=86400", # 缓存一天 + "Content-Type": content_type, + } + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/app/routers/notification.py b/backend/app/routers/notification.py new file mode 100644 index 0000000000000000000000000000000000000000..ca069f8f75b70eab5ecfa965970b1675dcc456a4 --- /dev/null +++ b/backend/app/routers/notification.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import Optional + +from fastapi import APIRouter +from pydantic import BaseModel + +from app.db.trend_subscription_dao import ( + create_channel, + delete_channel, + get_channel, + list_channels, + update_channel, +) +from app.services.notification import NotificationService +from app.utils.response import ResponseWrapper as R + +router = APIRouter() + + +class CreateChannelRequest(BaseModel): + name: str + type: str # "webhook" | "bark" | "email" + config: dict # type-specific + + +class UpdateChannelRequest(BaseModel): + name: Optional[str] = None + type: Optional[str] = None + config: Optional[dict] = None + enabled: Optional[bool] = None + + +def _channel_to_dict(ch) -> dict: + import json + + return { + "id": ch.id, + "name": ch.name, + "type": ch.type, + "config": json.loads(ch.config or "{}"), + "enabled": ch.enabled, + "created_at": ch.created_at.isoformat() if ch.created_at else None, + "updated_at": ch.updated_at.isoformat() if ch.updated_at else None, + } + + +@router.get("/notification_channels") +def get_channels(): + channels = list_channels() + return R.success([_channel_to_dict(c) for c in channels]) + + +@router.get("/notification_channels/{channel_id}") +def get_channel_detail(channel_id: int): + channel = get_channel(channel_id) + if channel is None: + return R.error(msg=f"通知通道 {channel_id} 不存在", code=404) + return R.success(_channel_to_dict(channel)) + + +@router.post("/notification_channels") +def create_notification_channel(data: CreateChannelRequest): + channel = create_channel(name=data.name, channel_type=data.type, config=data.config) + return R.success(_channel_to_dict(channel)) + + +@router.put("/notification_channels/{channel_id}") +def update_notification_channel(channel_id: int, data: UpdateChannelRequest): + channel = update_channel( + channel_id=channel_id, + name=data.name, + channel_type=data.type, + config=data.config, + enabled=data.enabled, + ) + if channel is None: + return R.error(msg=f"通知通道 {channel_id} 不存在", code=404) + return R.success(_channel_to_dict(channel)) + + +@router.delete("/notification_channels/{channel_id}") +def delete_notification_channel(channel_id: int): + ok = delete_channel(channel_id) + if not ok: + return R.error(msg=f"通知通道 {channel_id} 不存在", code=404) + return R.success(msg="已删除") + + +@router.post("/notification_channels/{channel_id}/test") +def test_notification_channel(channel_id: int): + result = NotificationService().send_test(channel_id) + if result.get("success"): + return R.success(result, msg="测试通知发送成功") + return R.error(msg=result.get("error", "发送失败"), code=400) diff --git a/backend/app/routers/provider.py b/backend/app/routers/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1152020fef0073c28a34f02b772db3980f843b --- /dev/null +++ b/backend/app/routers/provider.py @@ -0,0 +1,97 @@ +from typing import Optional +from fastapi import APIRouter +from pydantic import BaseModel + +from app.exceptions.provider import ProviderError +from app.models.model_config import ModelConfig +from app.services.model import ModelService +from app.utils.response import ResponseWrapper as R +from app.services.provider import ProviderService + +router = APIRouter() + +# 新增 type 字段 +class ProviderRequest(BaseModel): + name: str + api_key: str + base_url: str + logo: Optional[str] = None + type: str + +class TestRequest(BaseModel): + id: str + # 可选:指定用哪个 model 跑连通性测试;不传则用该 provider 在 DB 里的第一个模型 + model: Optional[str] = None +class ProviderUpdateRequest(BaseModel): + id: str + name: Optional[str] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + logo: Optional[str] = None + type: Optional[str] = None + enabled:Optional[int] = None + +@router.post("/add_provider") +def add_provider(data: ProviderRequest): + try: + res = ProviderService.add_provider( + name=data.name, + api_key=data.api_key, + base_url=data.base_url, + logo=data.logo, + type_=data.type + ) + return R.success(msg='添加模型供应商成功',data=res) + except Exception as e: + return R.error(msg=e) + +@router.get("/get_all_providers") +def get_all_providers(): + try: + res = ProviderService.get_all_providers_safe() + return R.success(data=res) + except Exception as e: + return R.error(msg=e) + +@router.get("/get_provider_by_id/{id}") +def get_provider_by_id(id: str): + try: + res = ProviderService.get_provider_by_id_safe(id) + return R.success(data=res) + except Exception as e: + return R.error(msg=e) +# +# @router.get("/get_provider_by_name/{name}") +# def get_provider_by_name(name: str): +# try: +# res = ProviderService.get_provider_by_name(name) +# return R.success(data=res) +# except Exception as e: +# return R.error(msg=e) + + +@router.post("/update_provider") +def update_provider(data: ProviderUpdateRequest): + try: + if all( + field is None + for field in [data.name, data.api_key, data.base_url, data.logo, data.type,data.enabled] + ): + return R.error(msg='请至少填写一个参数') + + updated_provider =ProviderService.update_provider( + id=data.id, + data=dict(data) + ) + if updated_provider: + return R.success(msg='更新模型供应商成功', data=updated_provider) + else: + return R.error(msg='更新模型供应商失败') + except Exception as e: + print(e) + return R.error(msg=str(e)) + +@router.post('/connect_test') +def gpt_connect_test(data: TestRequest): + ModelService().connect_test(data.id, model=data.model) + return R.success(msg='连接成功') diff --git a/backend/app/routers/trend_subscription.py b/backend/app/routers/trend_subscription.py new file mode 100644 index 0000000000000000000000000000000000000000..bdff298b920fc9cb1ec31e4ef25e014f19437f73 --- /dev/null +++ b/backend/app/routers/trend_subscription.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from typing import Optional + +from fastapi import APIRouter, Query +from pydantic import BaseModel + +from app.services.trend_subscription import TrendSubscriptionService +from app.utils.response import ResponseWrapper as R + +router = APIRouter() + + +class CreateSubscriptionRequest(BaseModel): + name: str + keywords: list[str] + platforms: Optional[list[str]] = None + match_mode: str = "any" + push_enabled: bool = False + push_channel_ids: Optional[list[int]] = None + + +class UpdateSubscriptionRequest(BaseModel): + name: Optional[str] = None + keywords: Optional[list[str]] = None + platforms: Optional[list[str]] = None + match_mode: Optional[str] = None + enabled: Optional[bool] = None + push_enabled: Optional[bool] = None + push_channel_ids: Optional[list[int]] = None + + +# ─── Collection-level routes (no path param) MUST come before parameterized routes ── + +@router.get("/trend_subscriptions") +def get_trend_subscriptions(): + return R.success(TrendSubscriptionService().list_subscriptions()) + + +@router.post("/trend_subscriptions") +def create_trend_subscription(data: CreateSubscriptionRequest): + return R.success( + TrendSubscriptionService().create_subscription( + name=data.name, + keywords=data.keywords, + platforms=data.platforms, + match_mode=data.match_mode, + push_enabled=data.push_enabled, + push_channel_ids=data.push_channel_ids, + ) + ) + + +@router.post("/trend_subscriptions/match_all") +def match_all_subscriptions(): + from app.services.scheduler import get_scheduler + + summary = get_scheduler().run_now() + return R.success(summary) + + +@router.get("/trend_matches") +def get_all_matches( + limit: int = Query(100, ge=1, le=500), + unread_only: bool = Query(False), +): + return R.success( + TrendSubscriptionService().list_matches(subscription_id=None, limit=limit, unread_only=unread_only) + ) + + +# ─── Parameterized routes (with {subscription_id}) ─────────────────────────────── + +@router.get("/trend_subscriptions/{subscription_id}") +def get_trend_subscription(subscription_id: int): + result = TrendSubscriptionService().get_subscription(subscription_id) + if result is None: + return R.error(msg=f"订阅 {subscription_id} 不存在", code=404) + return R.success(result) + + +@router.put("/trend_subscriptions/{subscription_id}") +def update_trend_subscription(subscription_id: int, data: UpdateSubscriptionRequest): + result = TrendSubscriptionService().update_subscription( + subscription_id=subscription_id, + name=data.name, + keywords=data.keywords, + platforms=data.platforms, + match_mode=data.match_mode, + enabled=data.enabled, + push_enabled=data.push_enabled, + push_channel_ids=data.push_channel_ids, + ) + if result is None: + return R.error(msg=f"订阅 {subscription_id} 不存在", code=404) + return R.success(result) + + +@router.delete("/trend_subscriptions/{subscription_id}") +def delete_trend_subscription(subscription_id: int): + ok = TrendSubscriptionService().delete_subscription(subscription_id) + if not ok: + return R.error(msg=f"订阅 {subscription_id} 不存在", code=404) + return R.success(msg="已删除") + + +@router.post("/trend_subscriptions/{subscription_id}/match") +def match_trend_subscription(subscription_id: int): + try: + result = TrendSubscriptionService().match_subscription(subscription_id) + # Also send push notifications if enabled and new matches found + if result["new_matches"] > 0: + try: + from app.services.notification import NotificationService + sub = TrendSubscriptionService().get_subscription(subscription_id) + if sub and sub.get("push_enabled") and sub.get("push_channel_ids"): + match_titles = [m["title"] for m in result["matches"]] + title = f"🔥 VideoMemo: {sub['name']} — {len(match_titles)} 条新热点" + body = "\n\n".join(f"• {t}" for t in match_titles[:10]) + if len(match_titles) > 10: + body += f"\n\n…共 {len(match_titles)} 条" + NotificationService().send_batch(sub["push_channel_ids"], title, body) + result["push_sent"] = True + except Exception: + result["push_sent"] = False + return R.success(result) + except ValueError as exc: + return R.error(msg=str(exc), code=404) + + +@router.get("/trend_subscriptions/{subscription_id}/matches") +def get_subscription_matches( + subscription_id: int, + limit: int = Query(100, ge=1, le=500), + unread_only: bool = Query(False), +): + return R.success( + TrendSubscriptionService().list_matches( + subscription_id=subscription_id, + limit=limit, + unread_only=unread_only, + ) + ) + + +@router.post("/trend_subscriptions/{subscription_id}/matches/read-all") +def mark_subscription_matches_read(subscription_id: int): + count = TrendSubscriptionService().mark_all_read(subscription_id) + return R.success({"marked_read": count}) diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/services/article.py b/backend/app/services/article.py new file mode 100644 index 0000000000000000000000000000000000000000..b06ed6ef17486dafea8c0fee8d0a4568dd9593fb --- /dev/null +++ b/backend/app/services/article.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +import json +import os +import uuid +from pathlib import Path +from typing import Callable + +from app.article_fetchers.base import ArticleContent, ArticleFetcher +from app.article_fetchers.generic import GenericArticleFetcher +from app.article_fetchers.wechat import WechatArticleFetcher +from app.article_fetchers.xiaohongshu import XiaohongshuArticleFetcher +from app.db.article_dao import ( + create_subscription, + get_article_item, + get_subscription, + link_subscription_item, + list_article_items, + list_subscriptions, + mark_article_summarized, + update_subscription_refresh, + upsert_article_item, +) +from app.enmus.task_status_enums import TaskStatus +from app.gpt.gpt_factory import GPTFactory +from app.models.gpt_model import GPTSource +from app.models.model_config import ModelConfig +from app.models.transcriber_model import TranscriptSegment +from app.services.provider import ProviderService + + +def _note_output_dir() -> Path: + path = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results")) + path.mkdir(parents=True, exist_ok=True) + return path + + +class ArticleService: + def __init__( + self, + fetchers: dict[str, ArticleFetcher] | None = None, + gpt_factory: Callable[[str, str], object] | None = None, + ): + self.fetchers = fetchers or { + "wechat_mp": WechatArticleFetcher(), + "xiaohongshu": XiaohongshuArticleFetcher(), + "generic_web": GenericArticleFetcher(), + } + self.gpt_factory = gpt_factory or self._create_gpt + + def generate_from_url( + self, + url: str, + platform: str, + provider_id: str, + model_name: str, + style: str = "", + extras: str = "", + task_id: str | None = None, + ) -> dict: + task_id = task_id or str(uuid.uuid4()) + try: + self._update_status(task_id, TaskStatus.PARSING) + article = self._fetcher(platform).fetch(url) + item = upsert_article_item(article) + self._update_status(task_id, TaskStatus.TRANSCRIBING) + + gpt = self.gpt_factory(model_name, provider_id) + markdown = gpt.summarize( + GPTSource( + segment=self._segments(article), + title=article.title, + tags="article", + style=style, + extras=extras, + ) + ) + + self._update_status(task_id, TaskStatus.SAVING) + self._write_note_json( + task_id, + article, + markdown, + int(getattr(gpt, "total_tokens", 0) or 0), + ) + mark_article_summarized(item.id, task_id) + self._update_status(task_id, TaskStatus.SUCCESS) + self._index_task(task_id) + return {"task_id": task_id, "article_item_id": item.id} + except Exception: + self._update_status(task_id, TaskStatus.FAILED) + raise + + def generate_from_content( + self, + url: str, + platform: str, + title: str, + content_text: str, + provider_id: str, + model_name: str, + style: str = "", + extras: str = "", + author_name: str = "", + task_id: str | None = None, + ) -> dict: + body = (content_text or "").strip() + if len(body) < 20: + raise ValueError("导入正文过短,无法生成总结") + task_id = task_id or str(uuid.uuid4()) + try: + self._update_status(task_id, TaskStatus.PARSING) + article = ArticleContent( + platform=platform or "generic_web", + url=url or f"manual://{task_id}", + article_id=url or task_id, + title=(title or "").strip() or "导入文章", + author_name=author_name, + content_text=body, + raw_metadata={"source": "manual_import"}, + ) + item = upsert_article_item(article) + self._update_status(task_id, TaskStatus.TRANSCRIBING) + + gpt = self.gpt_factory(model_name, provider_id) + markdown = gpt.summarize( + GPTSource( + segment=self._segments(article), + title=article.title, + tags="article", + style=style, + extras=extras, + ) + ) + + self._update_status(task_id, TaskStatus.SAVING) + self._write_note_json( + task_id, + article, + markdown, + int(getattr(gpt, "total_tokens", 0) or 0), + ) + mark_article_summarized(item.id, task_id) + self._update_status(task_id, TaskStatus.SUCCESS) + self._index_task(task_id) + return {"task_id": task_id, "article_item_id": item.id} + except Exception: + self._update_status(task_id, TaskStatus.FAILED) + raise + + def fetch_only_from_url(self, url: str, platform: str) -> dict: + article = self._fetcher(platform).fetch(url) + item = upsert_article_item(article) + return self._item_payload(item, include_content=True) + + def import_only_content( + self, + url: str, + platform: str, + title: str, + content_text: str, + author_name: str = "", + ) -> dict: + body = (content_text or "").strip() + if len(body) < 20: + raise ValueError("导入正文过短") + article_id = url or str(uuid.uuid4()) + article = ArticleContent( + platform=platform or "generic_web", + url=url or f"manual://{article_id}", + article_id=article_id, + title=(title or "").strip() or "导入文章", + author_name=author_name, + content_text=body, + raw_metadata={"source": "manual_import"}, + ) + item = upsert_article_item(article) + return self._item_payload(item, include_content=True) + + def search(self, platform: str, keyword: str, limit: int = 20) -> dict: + articles = self._fetcher(platform).search(keyword, limit) + items = [upsert_article_item(article) for article in articles] + return { + "platform": platform, + "keyword": keyword, + "status": "ok", + "message": "", + "items": [self._item_payload(item) for item in items], + } + + def refresh_subscription(self, subscription_id: int, limit: int = 20) -> dict: + subscription = get_subscription(subscription_id) + if not subscription: + raise ValueError("订阅不存在") + + fetcher = self._fetcher(subscription.platform) + if subscription.type == "publisher": + articles = fetcher.fetch_publisher(subscription.query, limit) + reason = f"publisher:{subscription.query}" + else: + articles = fetcher.search(subscription.query, limit) + reason = f"keyword:{subscription.query}" + + items = [] + for article in articles: + item = upsert_article_item(article) + link_subscription_item(subscription.id, item.id, reason) + items.append(item) + update_subscription_refresh(subscription.id) + return { + "subscription_id": subscription.id, + "count": len(items), + "items": [self._item_payload(item) for item in items], + } + + def summarize_item( + self, + item_id: int, + provider_id: str, + model_name: str, + style: str = "", + extras: str = "", + ) -> dict: + item = get_article_item(item_id) + if not item: + raise ValueError("文章不存在") + if item.task_id and item.summary_status == "summarized": + return {"task_id": item.task_id, "article_item_id": item.id} + return self.generate_from_url( + url=item.url, + platform=item.platform, + provider_id=provider_id, + model_name=model_name, + style=style, + extras=extras, + ) + + def list_items(self, subscription_id: int | None = None) -> list[dict]: + return [self._item_payload(item) for item in list_article_items(subscription_id)] + + def get_item(self, item_id: int) -> dict: + item = get_article_item(item_id) + if not item: + raise ValueError("文章不存在") + return self._item_payload(item, include_content=True) + + def create_subscription( + self, + platform: str, + subscription_type: str, + query: str, + label: str = "", + ) -> dict: + subscription = create_subscription(platform, subscription_type, query, label) + return self._subscription_payload(subscription) + + def list_subscriptions(self) -> list[dict]: + return [self._subscription_payload(item) for item in list_subscriptions()] + + def _fetcher(self, platform: str) -> ArticleFetcher: + if platform not in self.fetchers: + raise ValueError(f"不支持的文章平台:{platform}") + return self.fetchers[platform] + + def _item_payload(self, item, include_content: bool = False) -> dict: + payload = { + "id": item.id, + "platform": item.platform, + "title": item.title, + "url": item.url, + "author_name": item.author_name, + "author_id": item.author_id, + "cover_url": item.cover_url, + "published_at": item.published_at, + "summary_status": item.summary_status, + "task_id": item.task_id, + } + if include_content: + payload["content_text"] = (getattr(item, "content_text", "") or "").strip() + if not payload["content_text"] and item.task_id: + payload["content_text"] = self._content_from_note_result(item.task_id) + return payload + + def _content_from_note_result(self, task_id: str) -> str: + if not task_id: + return "" + result_path = _note_output_dir() / f"{task_id}.json" + if not result_path.exists(): + return "" + try: + payload = json.loads(result_path.read_text(encoding="utf-8")) + except Exception: + return "" + transcript = payload.get("transcript") or {} + return str(transcript.get("full_text") or "").strip() + + def _subscription_payload(self, item) -> dict: + return { + "id": item.id, + "platform": item.platform, + "type": item.type, + "query": item.query, + "label": item.label, + "enabled": item.enabled, + "last_error": item.last_error, + } + + def _create_gpt(self, model_name: str, provider_id: str): + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + raise ValueError("请选择模型和提供者") + return GPTFactory().from_config( + ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + model_name=model_name, + provider=provider["type"], + name=provider["name"], + ) + ) + + def _segments(self, article: ArticleContent) -> list[TranscriptSegment]: + paragraphs = [p.strip() for p in article.content_text.splitlines() if p.strip()] + if not paragraphs and article.content_text.strip(): + paragraphs = [article.content_text.strip()] + return [ + TranscriptSegment(start=float(index), end=float(index + 1), text=text) + for index, text in enumerate(paragraphs) + ] + + def _write_note_json( + self, + task_id: str, + article: ArticleContent, + markdown: str, + total_tokens: int, + ) -> None: + segments = self._segments(article) + payload = { + "markdown": markdown, + "transcript": { + "language": "zh", + "full_text": article.content_text, + "segments": [ + {"start": segment.start, "end": segment.end, "text": segment.text} + for segment in segments + ], + }, + "audio_meta": { + "file_path": "", + "title": article.title, + "duration": 0, + "cover_url": article.cover_url, + "platform": article.platform, + "video_id": article.article_id, + "raw_info": { + "source_type": "article", + "url": article.url, + "author_name": article.author_name, + "author_id": article.author_id, + "published_at": article.published_at, + "image_urls": article.image_urls, + **(article.raw_metadata or {}), + }, + "video_path": None, + }, + "total_tokens": total_tokens, + } + (_note_output_dir() / f"{task_id}.json").write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _update_status(self, task_id: str, status: TaskStatus) -> None: + payload = {"status": status.value, "paused": False} + (_note_output_dir() / f"{task_id}.status.json").write_text( + json.dumps(payload, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + def _index_task(self, task_id: str) -> None: + try: + from app.services.vector_store import VectorStoreManager + + VectorStoreManager().index_task(task_id) + except Exception: + pass diff --git a/backend/app/services/browser_cookie.py b/backend/app/services/browser_cookie.py new file mode 100644 index 0000000000000000000000000000000000000000..67f27fccc3c720a8144a5aeef943603a70f03e4b --- /dev/null +++ b/backend/app/services/browser_cookie.py @@ -0,0 +1,135 @@ +import subprocess +import sys +import webbrowser +from typing import Optional + +from app.services.cookie_manager import CookieConfigManager + + +PLATFORM_COOKIE_DOMAINS = { + "bilibili": ("bilibili.com",), + "youtube": ("youtube.com", "youtu.be"), + "douyin": ("douyin.com", "iesdouyin.com"), + "kuaishou": ("kuaishou.com",), + "xiaohongshu": ("xiaohongshu.com", "xhslink.com"), +} + +PLATFORM_LOGIN_URLS = { + "bilibili": "https://www.bilibili.com/", + "youtube": "https://www.youtube.com/", + "douyin": "https://www.douyin.com/", + "kuaishou": "https://www.kuaishou.com/", + "xiaohongshu": "https://www.xiaohongshu.com/", +} + +PLATFORM_LABELS = { + "bilibili": "B站", + "youtube": "YouTube", + "douyin": "抖音", + "kuaishou": "快手", + "xiaohongshu": "小红书", +} + + +class BrowserCookieError(Exception): + pass + + +def _extract_cookies_from_browser(browser: str): + try: + from yt_dlp.cookies import extract_cookies_from_browser + except Exception as exc: + raise BrowserCookieError("当前后端环境缺少 yt-dlp,无法从浏览器读取 Cookie") from exc + return extract_cookies_from_browser(browser) + + +def _cookie_domains_for_platform(platform: str) -> tuple[str, ...]: + return PLATFORM_COOKIE_DOMAINS.get(platform, (platform,)) + + +def _matches_platform_domain(cookie_domain: str, platform: str) -> bool: + domain = (cookie_domain or "").lstrip(".").lower() + return any( + domain == target or domain.endswith(f".{target}") + for target in _cookie_domains_for_platform(platform) + ) + + +def _format_cookie_pairs(cookies, platform: str) -> list[str]: + pairs = [] + seen = set() + for cookie in cookies: + name = getattr(cookie, "name", "") + value = getattr(cookie, "value", "") + domain = getattr(cookie, "domain", "") + if not name or not value or not _matches_platform_domain(domain, platform): + continue + key = (name, value) + if key in seen: + continue + seen.add(key) + pairs.append(f"{name}={value}") + return pairs + + +def _open_url_in_browser(url: str, browser: str) -> bool: + browser = (browser or "").strip().lower() + if sys.platform == "darwin" and browser: + mac_apps = { + "chrome": "Google Chrome", + "edge": "Microsoft Edge", + "firefox": "Firefox", + "safari": "Safari", + "brave": "Brave Browser", + "chromium": "Chromium", + "opera": "Opera", + "vivaldi": "Vivaldi", + } + app_name = mac_apps.get(browser) + if app_name: + try: + subprocess.Popen(["open", "-a", app_name, url]) + return True + except Exception: + pass + return webbrowser.open_new_tab(url) + + +def sync_browser_cookie( + platform: str, + browser: str, + manager: Optional[CookieConfigManager] = None, +) -> dict: + platform = (platform or "").strip() + browser = (browser or "").strip() + if not platform: + raise BrowserCookieError("平台不能为空") + if not browser: + raise BrowserCookieError("请选择浏览器") + + try: + cookies = _extract_cookies_from_browser(browser) + except Exception as exc: + if isinstance(exc, BrowserCookieError): + raise + raise BrowserCookieError(f"从浏览器读取 Cookie 失败:{exc}") from exc + + pairs = _format_cookie_pairs(cookies, platform) + if not pairs: + login_url = PLATFORM_LOGIN_URLS.get(platform) + opened = bool(login_url and _open_url_in_browser(login_url, browser)) + label = PLATFORM_LABELS.get(platform, platform) + opened_hint = f"已打开{label}页面,登录后再点击一键获取。" if opened else "" + raise BrowserCookieError( + f"未找到 {platform} 对应的浏览器 Cookie,请先在该浏览器登录对应平台。{opened_hint}" + ) + + cookie_str = "; ".join(pairs) + cookie_manager = manager or CookieConfigManager() + cookie_manager.set(platform, cookie_str, browser=browser) + return { + "platform": platform, + "browser": browser, + "cookie": cookie_str, + "count": len(pairs), + } diff --git a/backend/app/services/chat_service.py b/backend/app/services/chat_service.py new file mode 100644 index 0000000000000000000000000000000000000000..92c01ca911cb47e84eccee2a52059ac00480eb77 --- /dev/null +++ b/backend/app/services/chat_service.py @@ -0,0 +1,357 @@ +import json +import os +from typing import Optional + +from app.gpt.gpt_factory import GPTFactory +from app.gpt.utils import strip_think_blocks +from app.models.model_config import ModelConfig +from app.services.provider import ProviderService +from app.services.vector_store import VectorStoreManager, NOTE_OUTPUT_DIR +from app.services.chat_tools import TOOLS, execute_tool +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +def _load_task_brief(task_id: str) -> dict: + """读出某篇笔记的标题/平台/URL,用于源卡片展示。失败返回空 dict。""" + path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(path): + return {} + try: + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + except Exception: + return {} + am = data.get("audio_meta", {}) or {} + raw = am.get("raw_info", {}) or {} + return { + "title": am.get("title") or raw.get("title") or "(无标题)", + "platform": am.get("platform") or "", + "url": raw.get("webpage_url") or "", + "uploader": raw.get("uploader") or "", + } + +SYSTEM_PROMPT = """你是一个视频笔记问答助手。你拥有以下能力: + +1. 系统已自动检索了一些相关内容作为初始参考(见下方) +2. 你可以调用工具主动查询更多信息: + - lookup_transcript: 查询视频原始转录文本(支持按时间、关键词、位置筛选) + - get_video_info: 获取视频元信息(标题、作者、简介、标签等) + - get_note_content: 获取完整笔记内容 + +--- 初始检索内容 --- +{context} +--- + +回答要求: +- 如果初始检索内容不足以回答问题,请主动调用工具获取更多信息 +- 回答关于视频具体原话、细节时,用 lookup_transcript 查询原文 +- 回答关于作者、标题等基本信息时,用 get_video_info 查询 +- 请用中文回答,保持简洁准确""" + + +def _build_context(chunks: list[dict]) -> str: + """将检索到的片段拼接为上下文文本。""" + parts = [] + for chunk in chunks: + meta = chunk.get("metadata", {}) + source_type = meta.get("source_type", "unknown") + if source_type == "meta": + label = "[视频信息]" + elif source_type == "markdown": + label = f"[笔记 - {meta.get('section_title', '')}]" + else: + start = meta.get("start_time", 0) + end = meta.get("end_time", 0) + label = f"[转录 - {start:.0f}s~{end:.0f}s]" + parts.append(f"{label}\n{chunk['text']}") + return "\n\n".join(parts) + + +def _build_sources(chunks: list[dict]) -> list[dict]: + """从检索片段中提取来源信息。""" + sources = [] + for chunk in chunks: + meta = chunk.get("metadata", {}) + source = { + "text": chunk["text"][:200], + "source_type": meta.get("source_type", "unknown"), + } + if meta.get("section_title"): + source["section_title"] = meta["section_title"] + if meta.get("start_time") is not None: + source["start_time"] = meta["start_time"] + if meta.get("end_time") is not None: + source["end_time"] = meta["end_time"] + sources.append(source) + return sources + + +def chat( + task_id: str, + question: str, + history: list[dict], + provider_id: str, + model_name: str, +) -> dict: + """ + RAG + Tool Calling 问答。 + 1. 向量检索初始上下文 + 2. 调用 LLM(带 tools) + 3. 如果 LLM 调用了工具,执行工具并将结果返回给 LLM + 4. 循环直到 LLM 给出最终回答 + """ + vector_store = VectorStoreManager() + + # 1. 检索初始上下文 + chunks = vector_store.query(task_id, question, n_results=6) + context = _build_context(chunks) if chunks else "(未检索到相关内容,请使用工具查询)" + sources = _build_sources(chunks) if chunks else [] + + # 2. 构建消息 + system_msg = SYSTEM_PROMPT.format(context=context) + messages = [{"role": "system", "content": system_msg}] + + for msg in history[-20:]: + messages.append({"role": msg["role"], "content": msg["content"]}) + + messages.append({"role": "user", "content": question}) + + # 3. 获取 LLM client + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + raise ValueError(f"未找到模型供应商: {provider_id}") + + config = ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + model_name=model_name, + provider=provider["type"], + name=provider["name"], + ) + gpt = GPTFactory.from_config(config) + + logger.info(f"Chat: task_id={task_id}, model={model_name}") + + # 4. Tool calling 循环(最多 3 轮) + max_rounds = 3 + for round_i in range(max_rounds): + response = gpt.client.chat.completions.create( + model=gpt.model, + messages=messages, + tools=TOOLS, + temperature=0.7, + ) + + msg = response.choices[0].message + + # 没有工具调用,直接返回 + if not msg.tool_calls: + return {"answer": msg.content or "", "sources": sources} + + # 处理工具调用 + messages.append(msg) + + for tool_call in msg.tool_calls: + fn_name = tool_call.function.name + try: + fn_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + fn_args = {} + + logger.info(f"Tool call [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})") + + result = execute_tool(fn_name, fn_args, default_task_id=task_id) + + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + }) + + # 超过最大轮次,做最后一次不带 tools 的调用 + response = gpt.client.chat.completions.create( + model=gpt.model, + messages=messages, + temperature=0.7, + ) + + return {"answer": strip_think_blocks(response.choices[0].message.content), "sources": sources} + + +# ── 跨笔记知识库问答 ───────────────────────────────────────── + +ACROSS_SYSTEM_PROMPT = """你是一个跨视频笔记的知识库问答助手,可以同时基于多篇笔记回答问题。 + +工作方式: +1. 系统已经从知识库里检索到了若干个最相关的片段(见下方「初始检索内容」),每段都标注了它来自哪篇笔记 +2. 如果初始片段不足,你可以调用工具针对**指定 task_id**的笔记深挖: + - lookup_transcript(task_id, ...): 查询该笔记的转录文本 + - get_video_info(task_id): 获取该笔记的视频元信息 + - get_note_content(task_id): 获取该笔记的完整 Markdown + +--- 初始检索内容 --- +{context} +--- + +回答要求: +- 综合多篇笔记的信息作答,遇到不同观点要明确指出来自哪篇 +- 在正文里引用具体内容时,用《笔记标题》的形式标明出处 +- 如果检索结果只跟一两篇笔记相关,回答时不要硬凑其它笔记 +- 用中文回答,保持简洁准确""" + + +def _build_across_context(chunks: list[dict], briefs: dict[str, dict]) -> str: + """跨笔记 context:每段都标注来源笔记标题 + task_id。""" + parts = [] + for chunk in chunks: + meta = chunk.get("metadata", {}) + tid = chunk.get("task_id", "") + brief = briefs.get(tid, {}) + title = brief.get("title", "(无标题)") + source_type = meta.get("source_type", "unknown") + if source_type == "meta": + label = f"[来源:《{title}》· 视频信息 · task_id={tid}]" + elif source_type == "markdown": + label = f"[来源:《{title}》· 笔记 - {meta.get('section_title', '')} · task_id={tid}]" + else: + start = meta.get("start_time", 0) + end = meta.get("end_time", 0) + label = f"[来源:《{title}》· 转录 {start:.0f}s~{end:.0f}s · task_id={tid}]" + parts.append(f"{label}\n{chunk['text']}") + return "\n\n".join(parts) + + +def _build_across_sources(chunks: list[dict], briefs: dict[str, dict]) -> list[dict]: + """跨笔记 sources:每条带 task_id + 标题 + 平台 + URL,方便前端做引用卡片+跳转。""" + sources = [] + for chunk in chunks: + meta = chunk.get("metadata", {}) + tid = chunk.get("task_id", "") + brief = briefs.get(tid, {}) + source = { + "task_id": tid, + "title": brief.get("title", "(无标题)"), + "platform": brief.get("platform", ""), + "url": brief.get("url", ""), + "uploader": brief.get("uploader", ""), + "text": chunk["text"][:200], + "source_type": meta.get("source_type", "unknown"), + } + if meta.get("section_title"): + source["section_title"] = meta["section_title"] + if meta.get("start_time") is not None: + source["start_time"] = meta["start_time"] + if meta.get("end_time") is not None: + source["end_time"] = meta["end_time"] + sources.append(source) + return sources + + +def chat_across( + question: str, + history: list[dict], + scope: dict, + provider_id: str, + model_name: str, +) -> dict: + """ + 跨笔记知识库问答。 + scope: {"task_ids": [...] | None} + task_ids=None 或缺省 → 全库 + task_ids=[] → 视为没匹配到任何笔记 + task_ids=[...] → 只在这些笔记里检索 + """ + vector_store = VectorStoreManager() + + task_ids = scope.get("task_ids") if scope else None + # None = 全库;空数组 = 用户筛了但没选中任何笔记,直接告知 + if task_ids is not None and len(task_ids) == 0: + return { + "answer": "当前过滤条件下没有可检索的笔记。请放宽过滤条件后再试。", + "sources": [], + } + + # 1. 跨 collection 检索 + chunks = vector_store.query_across( + query_text=question, + task_ids=task_ids, + n_results_per_task=3, + max_total=12, + ) + + if not chunks: + return { + "answer": "知识库里还没有任何索引内容。请先生成几篇笔记后再来提问。" + if not vector_store.list_indexed_task_ids() + else "未检索到与问题相关的内容。可以试试换种问法,或放宽过滤条件。", + "sources": [], + } + + # 2. 命中的笔记 brief(标题/平台/URL) + hit_task_ids = list({c["task_id"] for c in chunks if c.get("task_id")}) + briefs = {tid: _load_task_brief(tid) for tid in hit_task_ids} + + context = _build_across_context(chunks, briefs) + sources = _build_across_sources(chunks, briefs) + + # 3. 构建消息 + system_msg = ACROSS_SYSTEM_PROMPT.format(context=context) + messages = [{"role": "system", "content": system_msg}] + for msg in history[-20:]: + messages.append({"role": msg["role"], "content": msg["content"]}) + messages.append({"role": "user", "content": question}) + + # 4. 获取 LLM client + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + raise ValueError(f"未找到模型供应商: {provider_id}") + config = ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + model_name=model_name, + provider=provider["type"], + name=provider["name"], + ) + gpt = GPTFactory.from_config(config) + + logger.info(f"ChatAcross: hit_tasks={len(hit_task_ids)}, chunks={len(chunks)}, model={model_name}") + + # 5. Tool calling 循环(最多 3 轮)—— 跨笔记场景不传 default_task_id,强制模型在 arguments 里指定 + max_rounds = 3 + for round_i in range(max_rounds): + response = gpt.client.chat.completions.create( + model=gpt.model, + messages=messages, + tools=TOOLS, + temperature=0.7, + ) + msg = response.choices[0].message + + if not msg.tool_calls: + return {"answer": msg.content or "", "sources": sources} + + messages.append(msg) + for tool_call in msg.tool_calls: + fn_name = tool_call.function.name + try: + fn_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError: + fn_args = {} + + logger.info(f"AcrossTool [{round_i+1}/{max_rounds}]: {fn_name}({fn_args})") + result = execute_tool(fn_name, fn_args) # 跨笔记:必须由 args 提供 task_id + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + }) + + # 超过最大轮次,最后一次不带 tools + response = gpt.client.chat.completions.create( + model=gpt.model, + messages=messages, + temperature=0.7, + ) + return {"answer": strip_think_blocks(response.choices[0].message.content), "sources": sources} diff --git a/backend/app/services/chat_tools.py b/backend/app/services/chat_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..2db298ecabff54c4f80e2d613269bfd6790bfe8b --- /dev/null +++ b/backend/app/services/chat_tools.py @@ -0,0 +1,207 @@ +""" +Chat function calling 工具定义与执行。 +提供给 LLM 调用,用于主动查询视频原文、笔记、元信息。 +""" + +import json +import os +from typing import Optional + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results") + + +def _load_note_data(task_id: str) -> Optional[dict]: + path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(path): + return None + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +# ── 工具定义(OpenAI function calling 格式)────────────────────── + +_TASK_ID_PARAM = { + "type": "string", + "description": "目标笔记的 task_id。跨笔记知识库问答时必填,单笔记会话可省略(系统会自动填入当前笔记)。", +} + + +TOOLS = [ + { + "type": "function", + "function": { + "name": "lookup_transcript", + "description": "查询某篇笔记对应视频的原始转录文本。可按时间范围筛选、按关键词搜索、或获取指定位置的内容。", + "parameters": { + "type": "object", + "properties": { + "task_id": _TASK_ID_PARAM, + "start_time": { + "type": "number", + "description": "起始时间(秒),例如 0 表示视频开头,60 表示第1分钟", + }, + "end_time": { + "type": "number", + "description": "结束时间(秒),不传则到末尾", + }, + "keyword": { + "type": "string", + "description": "搜索关键词,返回包含该关键词的转录片段", + }, + "position": { + "type": "string", + "enum": ["start", "end"], + "description": "快捷位置:start=视频开头前30句,end=视频结尾后30句", + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_video_info", + "description": "获取某篇笔记对应视频的完整元信息,包括标题、作者、简介、标签、时长、播放量等。", + "parameters": { + "type": "object", + "properties": { + "task_id": _TASK_ID_PARAM, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_note_content", + "description": "获取某篇笔记 AI 生成的完整内容(Markdown 格式)。", + "parameters": { + "type": "object", + "properties": { + "task_id": _TASK_ID_PARAM, + }, + "required": [], + }, + }, + }, +] + + +# ── 工具执行 ────────────────────────────────────────────────── + +def execute_tool( + tool_name: str, + arguments: dict, + default_task_id: Optional[str] = None, +) -> str: + """ + 执行工具调用,返回结果字符串。 + - tool 的 arguments 里若带 task_id 就用它(跨笔记场景) + - 否则 fallback 到 default_task_id(单笔记场景由上游注入) + """ + task_id = arguments.get("task_id") or default_task_id + if not task_id: + return json.dumps({"error": "缺少 task_id"}, ensure_ascii=False) + + data = _load_note_data(task_id) + if not data: + return json.dumps({"error": f"笔记数据不存在: {task_id}"}, ensure_ascii=False) + + if tool_name == "lookup_transcript": + return _lookup_transcript(data, arguments) + elif tool_name == "get_video_info": + return _get_video_info(data) + elif tool_name == "get_note_content": + return _get_note_content(data) + else: + return json.dumps({"error": f"未知工具: {tool_name}"}, ensure_ascii=False) + + +def _lookup_transcript(data: dict, args: dict) -> str: + segments = data.get("transcript", {}).get("segments", []) + if not segments: + return json.dumps({"error": "没有转录数据"}, ensure_ascii=False) + + position = args.get("position") + start_time = args.get("start_time") + end_time = args.get("end_time") + keyword = args.get("keyword", "").strip() + + # 快捷位置 + if position == "start": + filtered = segments[:30] + elif position == "end": + filtered = segments[-30:] + else: + filtered = segments + + # 时间筛选 + if start_time is not None: + filtered = [s for s in filtered if s.get("end", 0) >= start_time] + if end_time is not None: + filtered = [s for s in filtered if s.get("start", 0) <= end_time] + + # 关键词筛选 + if keyword: + filtered = [s for s in filtered if keyword.lower() in s.get("text", "").lower()] + + # 限制返回量,避免 token 爆炸 + if len(filtered) > 50: + filtered = filtered[:50] + truncated = True + else: + truncated = False + + result = { + "total_segments": len(data.get("transcript", {}).get("segments", [])), + "returned": len(filtered), + "truncated": truncated, + "segments": [ + { + "start": round(s.get("start", 0), 1), + "end": round(s.get("end", 0), 1), + "text": s.get("text", ""), + } + for s in filtered + ], + } + return json.dumps(result, ensure_ascii=False) + + +def _get_video_info(data: dict) -> str: + am = data.get("audio_meta", {}) + raw = am.get("raw_info", {}) or {} + + info = { + "title": am.get("title") or raw.get("title", ""), + "uploader": raw.get("uploader", ""), + "description": raw.get("description", "")[:1000], + "tags": raw.get("tags", [])[:20] if isinstance(raw.get("tags"), list) else [], + "duration_seconds": am.get("duration", 0), + "platform": am.get("platform", ""), + "video_id": am.get("video_id", ""), + "url": raw.get("webpage_url", ""), + "view_count": raw.get("view_count"), + "like_count": raw.get("like_count"), + "comment_count": raw.get("comment_count"), + } + # 去除 None 值 + info = {k: v for k, v in info.items() if v is not None and v != ""} + return json.dumps(info, ensure_ascii=False) + + +def _get_note_content(data: dict) -> str: + md = data.get("markdown", "") + if isinstance(md, list): + # 多版本,取最新 + md = md[-1].get("content", "") if md else "" + # 限制长度 + if len(md) > 5000: + md = md[:5000] + "\n\n... (内容过长已截断)" + return json.dumps({"markdown": md}, ensure_ascii=False) diff --git a/backend/app/services/constant.py b/backend/app/services/constant.py new file mode 100644 index 0000000000000000000000000000000000000000..9da0b8679f7a8519490629f30716c0105a52da45 --- /dev/null +++ b/backend/app/services/constant.py @@ -0,0 +1,16 @@ +from app.downloaders.bilibili_downloader import BilibiliDownloader +from app.downloaders.douyin_downloader import DouyinDownloader +from app.downloaders.kuaishou_downloader import KuaiShouDownloader +from app.downloaders.local_downloader import LocalDownloader +from app.downloaders.xiaohongshu_downloader import XiaohongshuDownloader +from app.downloaders.youtube_downloader import YoutubeDownloader + +SUPPORT_PLATFORM_MAP = { + 'youtube':YoutubeDownloader(), + 'bilibili':BilibiliDownloader(), + 'tiktok':DouyinDownloader(), + 'kuaishou':KuaiShouDownloader(), + 'douyin':DouyinDownloader(), + 'xiaohongshu':XiaohongshuDownloader(), + 'local':LocalDownloader() +} \ No newline at end of file diff --git a/backend/app/services/cookie_manager.py b/backend/app/services/cookie_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..94f7a902db2e92b6bd3e6cf220d236977700192a --- /dev/null +++ b/backend/app/services/cookie_manager.py @@ -0,0 +1,61 @@ +import json +from pathlib import Path +from typing import Optional, Dict + + +class CookieConfigManager: + def __init__(self, filepath: str = "config/downloader.json"): + self.path = Path(filepath) + self.path.parent.mkdir(parents=True, exist_ok=True) + if not self.path.exists(): + self._write({}) + + def _read(self) -> Dict[str, Dict[str, str]]: + try: + with self.path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + def _write(self, data: Dict[str, Dict[str, str]]): + with self.path.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def get(self, platform: str) -> Optional[str]: + data = self._read() + return data.get(platform, {}).get("cookie") + + def get_browser(self, platform: str) -> Optional[str]: + """读取该平台配置的「从浏览器读 cookie」选项,未配置返回 None。""" + data = self._read() + browser = data.get(platform, {}).get("browser") + return browser or None + + def set(self, platform: str, cookie: str, browser: Optional[str] = None): + """保存平台的 cookie 字符串及可选的浏览器名。 + + browser 传 None 表示不修改原浏览器设置;传空字符串则清除浏览器设置。 + """ + data = self._read() + entry = data.get(platform, {}) or {} + entry["cookie"] = cookie + if browser is not None: + if browser: + entry["browser"] = browser + else: + entry.pop("browser", None) + data[platform] = entry + self._write(data) + + def delete(self, platform: str): + data = self._read() + if platform in data: + del data[platform] + self._write(data) + + def list_all(self) -> Dict[str, str]: + data = self._read() + return {k: v.get("cookie", "") for k, v in data.items()} + + def exists(self, platform: str) -> bool: + return self.get(platform) is not None \ No newline at end of file diff --git a/backend/app/services/custom_platform_manager.py b/backend/app/services/custom_platform_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ec3185a428c4a12ee1380937460b51b7136ecb --- /dev/null +++ b/backend/app/services/custom_platform_manager.py @@ -0,0 +1,92 @@ +"""自定义平台管理:用户在 UI 里登记的额外平台条目。 + +每条记录形如 { key, name, match }: + - key: 平台唯一标识,作为 NoteGenerator._get_downloader 的 platform 入参; + Cookie 也用同样的 key 存到现有 CookieConfigManager(无需新存储)。 + - name: 展示名。 + - match: URL 子串匹配。如 "vimeo.com"。命中即视为该平台。 +""" +import json +import re +from pathlib import Path +from typing import Optional + + +_PATH = Path("config/custom_platforms.json") +_KEY_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{1,31}$") + + +def _read() -> list[dict]: + if not _PATH.exists(): + return [] + try: + data = json.loads(_PATH.read_text(encoding="utf-8")) + if isinstance(data, dict): + data = data.get("platforms", []) + return [p for p in (data or []) if isinstance(p, dict) and p.get("key")] + except Exception: + return [] + + +def _write(items: list[dict]) -> None: + _PATH.parent.mkdir(parents=True, exist_ok=True) + _PATH.write_text( + json.dumps({"platforms": items}, ensure_ascii=False, indent=2), + encoding="utf-8", + ) + + +def list_all() -> list[dict]: + return _read() + + +def get(key: str) -> Optional[dict]: + for p in _read(): + if p.get("key") == key: + return p + return None + + +def upsert(key: str, name: str, match: str) -> dict: + """创建或更新自定义平台。key 不可改(作为身份),name/match 可改。""" + key = (key or "").strip().lower() + if not _KEY_RE.match(key): + raise ValueError("平台标识只能是 2~32 位的小写字母、数字、下划线或短横线") + if key in {"youtube", "bilibili", "douyin", "kuaishou", "xiaohongshu", "tiktok", "local"}: + raise ValueError(f"标识 {key!r} 与内建平台冲突") + name = (name or "").strip() or key + match = (match or "").strip() + if not match: + raise ValueError("URL 匹配规则不能为空") + + items = _read() + found = False + for p in items: + if p["key"] == key: + p["name"], p["match"] = name, match + found = True + break + if not found: + items.append({"key": key, "name": name, "match": match}) + _write(items) + return next(p for p in items if p["key"] == key) + + +def delete(key: str) -> bool: + items = _read() + new_items = [p for p in items if p.get("key") != key] + if len(new_items) == len(items): + return False + _write(new_items) + return True + + +def match_custom_platform(url: str) -> Optional[dict]: + """URL → 自定义平台条目。返回首个 match 命中的项。""" + if not url: + return None + for p in _read(): + m = (p.get("match") or "").strip() + if m and m in url: + return p + return None diff --git a/backend/app/services/feishu_cli_service.py b/backend/app/services/feishu_cli_service.py new file mode 100644 index 0000000000000000000000000000000000000000..41a77f42a9dbb2c2aff8499d9ed8635fdd620713 --- /dev/null +++ b/backend/app/services/feishu_cli_service.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import json +import logging +import os +import re +import shutil +import subprocess +from typing import Any, Dict, List, Optional + +from app.services.feishu_config_manager import FeishuConfigManager +from app.services.feishu_service import FeishuError, FeishuService + +logger = logging.getLogger(__name__) + +CLI_TIMEOUT = 120 # seconds,导入/创建文档可能稍慢 + + +class FeishuCliService: + """通过官方 lark CLI(npm 包 @larksuite/cli,二进制名 lark-cli)推送笔记到飞书文档。 + + 鉴权用「机器人 key」即自建应用凭证:把 LARK_APP_ID / LARK_APP_SECRET 注入子进程环境, + CLI 会自动走 tenant_access_token 流程(与 REST 同一把 key、同一身份),无需交互式登录。 + 适用于后端独立部署(Docker,镜像内 npm 安装好 lark-cli)的场景。 + + 注意:lark-cli 的 `docs +create --markdown` 在部分版本存在「只写首行」的已知问题 + (larksuite/cli issue #82)。本类已用 --format json 解析返回,若你的版本仍截断, + 建议升级 lark-cli,或在「设置 → 飞书推送」把推送方式切回 REST 直连。 + """ + + def __init__(self, config: Optional[Dict[str, Any]] = None): + self.cfg = config or FeishuConfigManager().get_config() + self.cli_path = (self.cfg.get("cli_path") or "lark-cli").strip() or "lark-cli" + self.app_id = (self.cfg.get("app_id") or "").strip() + self.app_secret = (self.cfg.get("app_secret") or "").strip() + self.base_url = (self.cfg.get("base_url") or "https://open.feishu.cn").rstrip("/") + self.folder_token = (self.cfg.get("folder_token") or "").strip() + + # ─── 可用性 / 环境 ─────────────────────────────────────────────────────── + def resolve_cli(self) -> Optional[str]: + """返回 lark-cli 可执行文件的绝对路径;找不到返回 None。""" + return shutil.which(self.cli_path) + + def is_available(self) -> bool: + return self.resolve_cli() is not None + + def _env(self) -> Dict[str, str]: + env = os.environ.copy() + env["LARK_APP_ID"] = self.app_id + env["LARK_APP_SECRET"] = self.app_secret + # 海外 Lark 用 larksuite,国内飞书用 feishu;按配置的开放平台域名推断 + env["LARK_DOMAIN"] = "larksuite" if "larksuite" in self.base_url else "feishu" + # 避免 CLI 启交互式 TUI / 等待浏览器 + env["CI"] = "true" + env["NO_COLOR"] = "1" + return env + + def _run(self, args: List[str]) -> subprocess.CompletedProcess: + cli = self.resolve_cli() + if not cli: + raise FeishuError( + f"未找到 lark CLI({self.cli_path})。请在后端环境安装:" + "npm install -g @larksuite/cli,或在「设置 → 飞书推送」把推送方式切回 REST 直连" + ) + if not (self.app_id and self.app_secret): + raise FeishuError("飞书未配置 App ID / App Secret,无法用 lark-cli 推送") + try: + return subprocess.run( + [cli, *args], + env=self._env(), + capture_output=True, + text=True, + timeout=CLI_TIMEOUT, + ) + except FileNotFoundError as exc: + raise FeishuError(f"无法执行 lark-cli:{exc}") from exc + except subprocess.TimeoutExpired as exc: + raise FeishuError("lark-cli 执行超时(文档可能仍在生成中),可稍后重试") from exc + + # ─── 公有方法 ──────────────────────────────────────────────────────────── + def test_connection(self) -> Dict[str, Any]: + """验证:lark-cli 存在且凭证可用。凭证有效性用 REST 换 token 验证(同一把 key)。""" + if not self.is_available(): + raise FeishuError( + f"未找到 lark CLI({self.cli_path})。请安装 @larksuite/cli,或改用 REST 直连" + ) + # 与 CLI 同一把 app key:用 REST 换一次 token 即可确认凭证有效,不依赖 CLI 的登录态命令 + FeishuService(self.cfg)._get_tenant_access_token() + return {"success": True, "message": "lark-cli 已就绪,凭证有效"} + + def push_markdown( + self, + title: str, + markdown: str, + image_base_url: Optional[str] = None, + ) -> Dict[str, Any]: + if not (markdown or "").strip(): + raise FeishuError("笔记内容为空,无法推送") + + safe_title = FeishuService._safe_title(title) + prepared = FeishuService._prepare_markdown(markdown, image_base_url) + + args = [ + "docs", "+create", + "--title", safe_title, + "--markdown", prepared, + "--format", "json", + ] + # 指定目标文件夹(CLI 不同版本 flag 名可能不同,带上 folder token 尽量挂到指定目录) + if self.folder_token: + args += ["--folder-token", self.folder_token] + + proc = self._run(args) + if proc.returncode != 0: + err = (proc.stderr or proc.stdout or "").strip() + raise FeishuError(f"lark-cli 推送失败:{err[:400] or '未知错误'}") + + result = self._parse_output(proc.stdout, safe_title) + logger.info(f"lark-cli 推送成功:{safe_title} -> {result.get('url')}") + return result + + # ─── 输出解析 ──────────────────────────────────────────────────────────── + def _parse_output(self, stdout: str, title: str) -> Dict[str, Any]: + """从 lark-cli 的 JSON 输出里抽取文档 url / token,尽量兼容不同版本的字段结构。""" + url, token = "", "" + data: Any = None + text = (stdout or "").strip() + if text: + try: + data = json.loads(text) + except json.JSONDecodeError: + # 退而求其次:从纯文本里正则抓飞书文档链接 + m = re.search(r"https?://[^\s\"']*/(?:docx|docs|wiki)/[A-Za-z0-9]+", text) + if m: + url = m.group(0) + + if data is not None: + url = url or self._deep_find(data, _URL_KEYS, _looks_like_doc_url) or "" + token = self._deep_find(data, _TOKEN_KEYS) or "" + if not url: + # 没有现成 url 字段时,再从整段 JSON 文本兜底找一个文档链接 + m = re.search(r"https?://[^\s\"']*/(?:docx|docs|wiki)/[A-Za-z0-9]+", text) + if m: + url = m.group(0) + + if not url and not token: + raise FeishuError( + "lark-cli 已执行但未能解析出文档链接。" + f"请确认 lark-cli 版本与输出格式(原始输出:{text[:200]})" + ) + return {"url": url, "token": token, "type": "docx", "title": title} + + @staticmethod + def _deep_find(obj: Any, keys: tuple, predicate=None) -> Optional[str]: + """在嵌套 dict/list 里按候选 key 找第一个匹配(可选 predicate 进一步校验)的字符串值。""" + if isinstance(obj, dict): + for k, v in obj.items(): + if isinstance(v, str) and k.lower() in keys and (predicate is None or predicate(v)): + return v + for v in obj.values(): + found = FeishuCliService._deep_find(v, keys, predicate) + if found: + return found + elif isinstance(obj, list): + for item in obj: + found = FeishuCliService._deep_find(item, keys, predicate) + if found: + return found + return None + + +_URL_KEYS = ("url", "doc_url", "document_url", "link", "share_url") +_TOKEN_KEYS = ("token", "doc_token", "document_id", "obj_token", "document_token") + + +def _looks_like_doc_url(value: str) -> bool: + return value.startswith("http") and ("feishu" in value or "larksuite" in value or "lark" in value) diff --git a/backend/app/services/feishu_config_manager.py b/backend/app/services/feishu_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7c6b5512b5bdef526298fbd22a8d0e4529a2af51 --- /dev/null +++ b/backend/app/services/feishu_config_manager.py @@ -0,0 +1,109 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict, Optional + +# 飞书 / Lark 开放平台默认域名。海外租户用 open.larksuite.com, +# 国内租户用 open.feishu.cn(默认)。用户可在设置页切换。 +DEFAULT_FEISHU_BASE_URL = "https://open.feishu.cn" + + +class FeishuConfigManager: + """飞书(Lark)文档推送配置,存 JSON 文件,前端可动态修改。 + + 存自建应用凭证(app_id / app_secret)、目标文件夹 token、是否自动推送等。 + app_secret 属敏感信息,只落本地配置文件(与 cookie 等本地凭证一致), + 返回给前端时通过 get_public_config() 隐去明文。 + """ + + def __init__(self, filepath: str = "config/feishu.json"): + self.path = Path(filepath) + self.path.parent.mkdir(parents=True, exist_ok=True) + + def _read(self) -> Dict[str, Any]: + if not self.path.exists(): + return {} + try: + with self.path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + def _write(self, data: Dict[str, Any]): + with self.path.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def get_config(self) -> Dict[str, Any]: + """内部使用:含 app_secret 明文。""" + data = self._read() + base_url = (data.get("base_url") or "").strip() or os.getenv( + "FEISHU_BASE_URL", DEFAULT_FEISHU_BASE_URL + ) + # 推送引擎:rest=直连开放平台(默认,可靠、同一把 key、无需 lark-cli); + # cli=强制走 lark-cli;auto=有 lark-cli 就用 CLI 否则回退 REST。 + # 默认用 rest:lark-cli 的 docs +create --markdown 在部分版本会把笔记截成首行 + # (issue #82),且截断不报错、auto 无法自动回退,所以让用户显式 opt-in CLI。 + backend = (data.get("push_backend") or "rest").strip().lower() + if backend not in ("auto", "rest", "cli"): + backend = "rest" + return { + "enabled": bool(data.get("enabled", False)), + "auto_push": bool(data.get("auto_push", False)), + "app_id": (data.get("app_id") or "").strip(), + "app_secret": (data.get("app_secret") or "").strip(), + "folder_token": (data.get("folder_token") or "").strip(), + "base_url": base_url.rstrip("/"), + "push_backend": backend, + "cli_path": (data.get("cli_path") or "lark-cli").strip() or "lark-cli", + } + + def get_public_config(self) -> Dict[str, Any]: + """给前端展示:隐去 app_secret 明文,只回 app_secret_set 表示是否已配置。""" + cfg = self.get_config() + has_secret = bool(cfg.pop("app_secret", "")) + cfg["app_secret_set"] = has_secret + cfg["configured"] = bool(cfg["app_id"] and has_secret) + return cfg + + def update_config( + self, + enabled: Optional[bool] = None, + auto_push: Optional[bool] = None, + app_id: Optional[str] = None, + app_secret: Optional[str] = None, + folder_token: Optional[str] = None, + base_url: Optional[str] = None, + push_backend: Optional[str] = None, + cli_path: Optional[str] = None, + ) -> Dict[str, Any]: + data = self._read() + if enabled is not None: + data["enabled"] = bool(enabled) + if auto_push is not None: + data["auto_push"] = bool(auto_push) + if app_id is not None: + data["app_id"] = app_id.strip() + # app_secret 仅在传入非空时覆盖:前端不回显明文,留空 == 不修改, + # 避免「只改了别的字段」时把已存的密钥清空。 + if app_secret is not None and app_secret.strip(): + data["app_secret"] = app_secret.strip() + if folder_token is not None: + data["folder_token"] = folder_token.strip() + if base_url is not None: + data["base_url"] = base_url.strip() + if push_backend is not None: + pb = push_backend.strip().lower() + if pb in ("auto", "rest", "cli"): + data["push_backend"] = pb + if cli_path is not None: + data["cli_path"] = cli_path.strip() + self._write(data) + return self.get_public_config() + + def is_configured(self) -> bool: + cfg = self.get_config() + return bool(cfg["app_id"] and cfg["app_secret"]) + + def is_auto_push_enabled(self) -> bool: + cfg = self.get_config() + return bool(cfg["enabled"] and cfg["auto_push"] and cfg["app_id"] and cfg["app_secret"]) diff --git a/backend/app/services/feishu_pusher.py b/backend/app/services/feishu_pusher.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ca7132bc7f20b825c8d0ef84ebbb7e24cc84d3 --- /dev/null +++ b/backend/app/services/feishu_pusher.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional + +from app.services.feishu_cli_service import FeishuCliService +from app.services.feishu_config_manager import FeishuConfigManager +from app.services.feishu_service import FeishuError, FeishuService + +logger = logging.getLogger(__name__) + + +def resolve_backend(cfg: Dict[str, Any]) -> str: + """把配置里的 push_backend 解析成实际使用的引擎:'rest' 或 'cli'。 + + - rest / cli:按用户显式选择; + - auto:环境里有 lark-cli 且配了凭证就走 CLI,否则回退 REST。 + """ + backend = (cfg.get("push_backend") or "auto").lower() + if backend in ("rest", "cli"): + return backend + # auto + try: + if cfg.get("app_id") and cfg.get("app_secret") and FeishuCliService(cfg).is_available(): + return "cli" + except Exception: + pass + return "rest" + + +def push_markdown( + title: str, + markdown: str, + image_base_url: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """按配置选择引擎推送笔记。auto 模式下 CLI 失败会自动回落到 REST。""" + cfg = config or FeishuConfigManager().get_config() + backend = resolve_backend(cfg) + + if backend == "cli": + try: + return FeishuCliService(cfg).push_markdown(title, markdown, image_base_url) + except FeishuError: + # 仅 auto 模式才回落;显式选 cli 时把错误抛给用户,避免「以为走了 CLI 其实没走」 + if (cfg.get("push_backend") or "auto").lower() == "auto": + logger.warning("lark-cli 推送失败,回落到 REST 直连导入") + return FeishuService(cfg).push_markdown(title, markdown, image_base_url) + raise + + return FeishuService(cfg).push_markdown(title, markdown, image_base_url) + + +def test_connection(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """按当前引擎做连通性/凭证校验,返回 {success, message, backend}。""" + cfg = config or FeishuConfigManager().get_config() + backend = resolve_backend(cfg) + + # 显式选 cli 时 resolve_backend 必返回 'cli'(即便没装),由 CLI 服务给出「未找到 lark-cli」的明确报错 + if backend == "cli": + res = FeishuCliService(cfg).test_connection() + else: + res = FeishuService(cfg).test_connection() + + res["backend"] = backend + return res diff --git a/backend/app/services/feishu_service.py b/backend/app/services/feishu_service.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5c7b99dba84d77884e1dc5d15bb5475ac30876 --- /dev/null +++ b/backend/app/services/feishu_service.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +import logging +import re +import time +from typing import Any, Dict, Optional, Tuple + +import requests + +from app.services.feishu_config_manager import FeishuConfigManager + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 30 # seconds +_POLL_INTERVAL = 1.5 # seconds between import-task polls +_POLL_MAX_ATTEMPTS = 40 # ~60s 上限,导入大文档也够用 + +# import_tasks 的 job_status:0=成功,1=初始化中,2=处理中,其余为错误码。 +_JOB_STATUS_SUCCESS = 0 +_JOB_STATUS_IN_PROGRESS = {1, 2} + + +class FeishuError(Exception): + """飞书推送相关错误,message 直接面向用户(前端会原样展示)。""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + +class FeishuService: + """飞书(Lark)开放平台客户端:把 Markdown 笔记导入为飞书云文档(docx)。 + + 走「导入任务」接口(drive/v1/import_tasks),由飞书原生把 Markdown 转成 docx, + 标题 / 列表 / 代码块 / 表格等格式保真度最好,比手动拼 block 更可靠。 + + 流程(官方 import-user-guide): + 1. medias/upload_all 上传 .md 源文件 → file_token + 2. import_tasks 创建导入任务(type=docx, file_extension=md)→ ticket + 3. 轮询 import_tasks/{ticket} 直到 job_status=0 → 拿到新文档 url / token + + 鉴权用自建应用的 tenant_access_token(应用身份)。注意:以应用身份创建的文档 + 归应用所有,普通用户要能看到,需要把目标文件夹的协作者加上该应用——这点在前端 + 配置页有说明。 + """ + + # tenant_access_token 进程内缓存:key=(base_url, app_id) -> (token, expire_ts) + _token_cache: Dict[Tuple[str, str], Tuple[str, float]] = {} + # 应用根目录 token 缓存:key=(base_url, app_id) -> folder_token + _root_folder_cache: Dict[Tuple[str, str], str] = {} + + def __init__(self, config: Optional[Dict[str, Any]] = None): + self.cfg = config or FeishuConfigManager().get_config() + self.base_url = (self.cfg.get("base_url") or "https://open.feishu.cn").rstrip("/") + self.app_id = (self.cfg.get("app_id") or "").strip() + self.app_secret = (self.cfg.get("app_secret") or "").strip() + self.folder_token = (self.cfg.get("folder_token") or "").strip() + + @property + def _cache_key(self) -> Tuple[str, str]: + return (self.base_url, self.app_id) + + # ─── 鉴权 ──────────────────────────────────────────────────────────────── + def _get_tenant_access_token(self) -> str: + if not self.app_id or not self.app_secret: + raise FeishuError( + "飞书未配置:请到「设置 → 飞书推送」填写 App ID 与 App Secret" + ) + + cached = FeishuService._token_cache.get(self._cache_key) + if cached and cached[1] > time.time(): + return cached[0] + + url = f"{self.base_url}/open-apis/auth/v3/tenant_access_token/internal" + try: + resp = requests.post( + url, + json={"app_id": self.app_id, "app_secret": self.app_secret}, + timeout=DEFAULT_TIMEOUT, + ) + data = resp.json() + except Exception as exc: + raise FeishuError(f"连接飞书失败:{exc}") from exc + + if data.get("code") != 0: + raise FeishuError( + f"飞书鉴权失败(code={data.get('code')}):{data.get('msg')}。" + "请检查 App ID / App Secret 是否正确、应用是否启用" + ) + + # tenant_access_token / expire 在响应顶层(这个老接口不包在 data 里) + token = data.get("tenant_access_token") + if not token: + raise FeishuError(f"飞书鉴权异常:响应缺少 tenant_access_token({data})") + expire_in = int(data.get("expire", 7200)) + # 提前 5 分钟过期,避开临界点 + FeishuService._token_cache[self._cache_key] = (token, time.time() + expire_in - 300) + return token + + def _auth_headers(self) -> Dict[str, str]: + return {"Authorization": f"Bearer {self._get_tenant_access_token()}"} + + @staticmethod + def _fmt_api_error(prefix: str, payload: Dict[str, Any]) -> str: + """统一格式化飞书业务错误,带上 code/msg 方便用户/开发定位。""" + return f"{prefix}(code={payload.get('code')}):{payload.get('msg')}" + + def _root_folder_token(self) -> str: + """未配置目标文件夹时,取应用云空间根目录 token 兜底。""" + cached = FeishuService._root_folder_cache.get(self._cache_key) + if cached: + return cached + url = f"{self.base_url}/open-apis/drive/explorer/v2/root_folder/meta" + try: + resp = requests.get(url, headers=self._auth_headers(), timeout=DEFAULT_TIMEOUT) + payload = resp.json() + except Exception as exc: + raise FeishuError(f"获取飞书根目录失败:{exc}") from exc + if payload.get("code") != 0: + raise FeishuError( + self._fmt_api_error("获取飞书根目录失败", payload) + + "。建议在「设置 → 飞书推送」直接填写目标文件夹 token" + ) + token = (payload.get("data") or {}).get("token", "") + if not token: + raise FeishuError("飞书根目录 token 为空,请在配置里指定目标文件夹 token") + FeishuService._root_folder_cache[self._cache_key] = token + return token + + # ─── 公有方法 ──────────────────────────────────────────────────────────── + def test_connection(self) -> Dict[str, Any]: + """验证凭证:能成功换取 tenant_access_token 即视为连接成功。""" + self._get_tenant_access_token() + return {"success": True, "message": "飞书连接成功,凭证有效"} + + def push_markdown( + self, + title: str, + markdown: str, + image_base_url: Optional[str] = None, + ) -> Dict[str, Any]: + """把 Markdown 导入为飞书云文档。返回 {url, token, type, title}。 + + :param title: 文档标题(取视频标题) + :param markdown: 笔记 Markdown 正文 + :param image_base_url: 把正文里 /static、/uploads 等相对图片链接补成绝对地址的前缀, + 飞书服务端导入时会按 http(s) 抓图(本机/内网地址抓不到则跳过) + """ + if not (markdown or "").strip(): + raise FeishuError("笔记内容为空,无法推送") + + safe_title = self._safe_title(title) + prepared = self._prepare_markdown(markdown, image_base_url) + content = prepared.encode("utf-8") + + folder_token = self.folder_token or self._root_folder_token() + file_token = self._upload_media(safe_title, content, folder_token) + ticket = self._create_import_task(safe_title, file_token, folder_token) + result = self._poll_import_task(ticket) + + token = result.get("token") or "" + doc_type = result.get("type") or "docx" + url = result.get("url") or self._fallback_doc_url(doc_type, token) + logger.info(f"飞书导入成功:{safe_title} -> {url}") + return {"url": url, "token": token, "type": doc_type, "title": safe_title} + + # ─── 导入流程内部步骤 ───────────────────────────────────────────────────── + def _upload_media(self, title: str, content: bytes, folder_token: str) -> str: + """步骤 1:上传 .md 源文件,拿 file_token。""" + url = f"{self.base_url}/open-apis/drive/v1/medias/upload_all" + file_name = f"{title}.md" + data = { + "file_name": file_name, + "parent_type": "ccm_import_open", # 导入专用素材类型 + "parent_node": folder_token, + "size": str(len(content)), + "extra": json.dumps({"obj_type": "docx", "file_extension": "md"}), + } + files = {"file": (file_name, content, "text/markdown")} + try: + resp = requests.post( + url, + headers=self._auth_headers(), + data=data, + files=files, + timeout=DEFAULT_TIMEOUT, + ) + payload = resp.json() + except Exception as exc: + raise FeishuError(f"上传 Markdown 到飞书失败:{exc}") from exc + if payload.get("code") != 0: + raise FeishuError(self._fmt_api_error("上传 Markdown 到飞书失败", payload)) + file_token = (payload.get("data") or {}).get("file_token") + if not file_token: + raise FeishuError(f"飞书上传异常:响应缺少 file_token({payload})") + return file_token + + def _create_import_task(self, title: str, file_token: str, folder_token: str) -> str: + """步骤 2:创建导入任务(md → docx),拿 ticket。""" + url = f"{self.base_url}/open-apis/drive/v1/import_tasks" + body = { + "file_extension": "md", + "file_token": file_token, + "type": "docx", + "file_name": title, + "point": { + "mount_type": 1, # 1 = 挂载到云空间文件夹 + "mount_key": folder_token, + }, + } + try: + resp = requests.post( + url, + headers={**self._auth_headers(), "Content-Type": "application/json"}, + json=body, + timeout=DEFAULT_TIMEOUT, + ) + payload = resp.json() + except Exception as exc: + raise FeishuError(f"创建飞书导入任务失败:{exc}") from exc + if payload.get("code") != 0: + raise FeishuError(self._fmt_api_error("创建飞书导入任务失败", payload)) + ticket = (payload.get("data") or {}).get("ticket") + if not ticket: + raise FeishuError(f"飞书导入异常:响应缺少 ticket({payload})") + return ticket + + def _poll_import_task(self, ticket: str) -> Dict[str, Any]: + """步骤 3:轮询导入结果,成功返回 result 字典(含 url/token)。""" + url = f"{self.base_url}/open-apis/drive/v1/import_tasks/{ticket}" + last_result: Dict[str, Any] = {} + for _ in range(_POLL_MAX_ATTEMPTS): + try: + resp = requests.get(url, headers=self._auth_headers(), timeout=DEFAULT_TIMEOUT) + payload = resp.json() + except Exception as exc: + raise FeishuError(f"查询飞书导入结果失败:{exc}") from exc + if payload.get("code") != 0: + raise FeishuError(self._fmt_api_error("查询飞书导入结果失败", payload)) + + last_result = (payload.get("data") or {}).get("result") or {} + job_status = last_result.get("job_status") + + if job_status == _JOB_STATUS_SUCCESS and last_result.get("token"): + return last_result + if job_status in _JOB_STATUS_IN_PROGRESS or job_status is None: + time.sleep(_POLL_INTERVAL) + continue + # 其余 job_status 视为失败 + err = last_result.get("job_error_msg") or "未知错误" + raise FeishuError(f"飞书导入失败(job_status={job_status}):{err}") + + # 轮询超时:可能仍在处理,给出可理解的提示 + raise FeishuError( + "飞书导入超时(文档可能仍在生成中)。请稍后到飞书云空间查看," + f"或重试推送(job_status={last_result.get('job_status')})" + ) + + def _fallback_doc_url(self, doc_type: str, token: str) -> str: + """飞书偶尔不回 url,按文档类型 + token 拼一个可访问地址。""" + if not token: + return "" + # open.feishu.cn → 主站 feishu.cn;open.larksuite.com → larksuite.com + host = self.base_url.replace("https://open.", "https://").replace("http://open.", "http://") + return f"{host}/{doc_type}/{token}" + + # ─── Markdown 预处理 ────────────────────────────────────────────────────── + @staticmethod + def _safe_title(title: str, fallback: str = "VideoMemo 笔记") -> str: + """清掉标题里的换行/控制字符,截到 120 字(飞书文档标题有长度限制)。""" + cleaned = re.sub(r"[\r\n\t]", " ", (title or "").strip()) + cleaned = cleaned.strip() + return (cleaned[:120] or fallback) + + @staticmethod + def _prepare_markdown(markdown: str, image_base_url: Optional[str]) -> str: + """把 /static、/uploads 开头的相对图片链接补成绝对地址,便于飞书抓图。 + + 飞书导入时按 http(s) 抓取图片,相对路径它无法解析;补成后端绝对地址后, + 仅当后端对飞书服务端可达(公网部署)时图片才会真正落进文档,本机/内网下 + 飞书抓不到会自动跳过该图,不影响正文导入。 + """ + if not image_base_url: + return markdown + base = image_base_url.rstrip("/") + + def _repl(m: "re.Match[str]") -> str: + alt, path = m.group(1), m.group(2) + return f"![{alt}]({base}{path})" + + # 仅替换 ](/static...) 与 ](/uploads...) 这类站内相对图片 + return re.sub(r"!\[([^\]]*)\]\((/(?:static|uploads)/[^)]+)\)", _repl, markdown) diff --git a/backend/app/services/hot_videos.py b/backend/app/services/hot_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..2059ee1171a2bc4be6c9fc4f52e2e38c2267c9a5 --- /dev/null +++ b/backend/app/services/hot_videos.py @@ -0,0 +1,737 @@ +from __future__ import annotations + +import json +import re +from json import JSONDecodeError +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from time import monotonic +from typing import Any, Callable, Literal + +import requests + +HotPlatform = Literal[ + # ── Video ── + "bilibili", "bilibili-hot-search", "youtube", "douyin", "kuaishou", "xiaohongshu", + # ── News / social ── + "weibo", "zhihu", "baidu", "toutiao", "thepaper", "ifeng", + "tieba", "hupu", "tencent", "tencent-hot", + "cankaoxiaoxi", "zaobao", "sputniknewscn", + "chongbuluo", "chongbuluo-hot", "chongbuluo-latest", + "kaopu", "hackernews", "producthunt", + "v2ex", "v2ex-share", "solidot", + "sspai", "coolapk", "douban", "nowcoder", "pcbeta", "pcbeta-windows11", + # ── Finance ── + "wallstreetcn", "wallstreetcn-hot", "wallstreetcn-news", "wallstreetcn-quick", + "cls", "cls-hot", "cls-telegraph", "cls-depth", + "36kr", "36kr-quick", "36kr-renqi", + "jin10", "gelonghui", "xueqiu", "xueqiu-hotstock", + "mktnews", "mktnews-flash", "fastbull", "fastbull-express", "fastbull-news", + # ── IT / dev ── + "ithome", "juejin", "github", "github-trending-today", "freebuf", "aihot", +] +HotPlatformFilter = Literal["all"] | HotPlatform +PlatformStatus = Literal["ok", "error", "unavailable"] + +SUPPORTED_HOT_PLATFORMS: tuple[HotPlatform, ...] = tuple(HotPlatform.__args__) # type: ignore[attr-defined] +CACHE_TTL_SECONDS = 600 +DEFAULT_TIMEOUT_SECONDS = 6 +BILIBILI_POPULAR_API_URL = "https://api.bilibili.com/x/web-interface/popular" +BILIBILI_POPULAR_READER_URL = ( + "https://r.jina.ai/http://r.jina.ai/http://https://www.bilibili.com/v/popular/all" +) +BILIBILI_POPULAR_SNAPSHOT: tuple[dict[str, str], ...] = ( + { + "id": "BV1xkE46PE59", + "title": "参加了九次高考的考生采访", + "cover_url": "https://i1.hdslb.com/bfs/archive/2c51e17d5ab278033691d055015af26638867e5b.jpg@412w_232h_1c_!web-popular.avif", + "author": "自来卷三木", + "hot_score": "275.2万播放", + }, + { + "id": "BV1Z37D6UEYz", + "title": "被 偷 家 了 !", + "cover_url": "https://i1.hdslb.com/bfs/archive/127ba31467e867d7c850176898391dd977cf4c52.jpg@412w_232h_1c_!web-popular.avif", + "author": "野生鱼白", + "hot_score": "116.7万播放", + }, + { + "id": "BV1wE7m67EAF", + "title": "“我在成华大道「拼接遗憾」”", + "cover_url": "https://i1.hdslb.com/bfs/archive/b7a88bfce9322211ddb670d69d549967dba6591e.jpg@412w_232h_1c_!web-popular.avif", + "author": "15万点赞 Dieorite", + "hot_score": "69.1万播放", + }, + { + "id": "BV16kV26TES7", + "title": "当父母说出了正确答案:", + "cover_url": "https://i0.hdslb.com/bfs/archive/982b25c4ef21fb08f77ae933c89f4b99d5b0c86a.jpg@412w_232h_1c_!web-popular.avif", + "author": "百万播放 进击的金厂长", + "hot_score": "636.6万播放", + }, + { + "id": "BV1QzVR63E3L", + "title": "回忆永远都是加分项", + "cover_url": "https://i2.hdslb.com/bfs/archive/f4dc3fe8ea88a450877315e867b4995121d5dbf6.jpg@412w_232h_1c_!web-popular.avif", + "author": "百万播放 央视新闻", + "hot_score": "381万播放", + }, + { + "id": "BV1o47S6UEQy", + "title": "仅粉丝可见的神颜", + "cover_url": "https://i2.hdslb.com/bfs/archive/e6c40e8a1ff00524db7ed3aaf7269a0b772fc266.jpg@412w_232h_1c_!web-popular.avif", + "author": "十一的看脸日记", + "hot_score": "162.2万播放", + }, + { + "id": "BV11wE46hETZ", + "title": "没错,针对的就是日本和菲律宾!", + "cover_url": "https://i1.hdslb.com/bfs/archive/df29614140c555cf5b6ded8663f3289dc532104a.jpg@412w_232h_1c_!web-popular.avif", + "author": "央视军事", + "hot_score": "19.6万播放", + }, + { + "id": "BV1TJE862ESA", + "title": "《女神异闻录4 Revival》预购宣传片", + "cover_url": "https://i1.hdslb.com/bfs/archive/7b600cf3f7dc7a7557ee87005d0ee09b76c64bc6.jpg@412w_232h_1c_!web-popular.avif", + "author": "2万分享 SEGA世嘉官方", + "hot_score": "30.4万播放", + }, + { + "id": "BV1AnEx68Eh5", + "title": "特厨探店|明星开的小面馆,真实水平到底怎么样—八号院儿", + "cover_url": "https://i0.hdslb.com/bfs/archive/ef4fdbb48e2a5d5ea262e50d686a682df48c1727.jpg@412w_232h_1c_!web-popular.avif", + "author": "特厨隋坡", + "hot_score": "98万播放", + }, + { + "id": "BV1kME464EJs", + "title": "小小极客湾,拿下!", + "cover_url": "https://i2.hdslb.com/bfs/archive/ba90a910e391dc0aab61116cff6707a1253fa15d.jpg@412w_232h_1c_!web-popular.avif", + "author": "笔吧评测室", + "hot_score": "62万播放", + }, + { + "id": "BV1AsEx6oE3t", + "title": "给我把桑多涅变回来!", + "cover_url": "https://i2.hdslb.com/bfs/archive/511d7c8790049145678aca03eaac77fd4f8826b6.jpg@412w_232h_1c_!web-popular.avif", + "author": "白银新手", + "hot_score": "10.5万播放", + }, + { + "id": "BV1MPEH6REki", + "title": "菜月昴老师,我还记得你", + "cover_url": "https://i1.hdslb.com/bfs/archive/1f4223d2a450a938cb907407be90e8091333ccc2.jpg@412w_232h_1c_!web-popular.avif", + "author": "人气飙升 完勒Linew", + "hot_score": "22.5万播放", + }, +) + + +@dataclass(frozen=True) +class HotVideoItem: + id: str + platform: HotPlatform + title: str + url: str + cover_url: str = "" + author: str = "" + rank: int = 0 + hot_score: str = "" + source: str = "" + + def to_dict(self) -> dict[str, Any]: + return asdict(self) + + +@dataclass(frozen=True) +class PlatformHotVideoResult: + platform: HotPlatform + status: PlatformStatus + message: str + items: list[HotVideoItem] + + def to_dict(self) -> dict[str, Any]: + return { + "platform": self.platform, + "status": self.status, + "message": self.message, + "items": [item.to_dict() for item in self.items], + } + + +def _now_iso() -> str: + return datetime.now(timezone.utc).astimezone().isoformat(timespec="seconds") + + +def _normalize_image_url(url: Any) -> str: + text = str(url or "").strip() + if text.startswith("//"): + return f"https:{text}" + if text.startswith("http://"): + return f"https://{text.removeprefix('http://')}" + return text + + +def _format_bilibili_views(view_count: Any) -> str: + try: + count = int(view_count or 0) + except (TypeError, ValueError): + count = 0 + if count >= 10000: + return f"{count / 10000:.1f}万播放" + return f"{count}播放" + + +def _map_bilibili_popular_items(payload: dict[str, Any], limit: int) -> list[HotVideoItem]: + rows = ((payload.get("data") or {}).get("list") or [])[:limit] + items: list[HotVideoItem] = [] + for row in rows: + bvid = str(row.get("bvid") or "").strip() + title = str(row.get("title") or "").strip() + if not bvid or not title: + continue + owner = row.get("owner") or {} + stat = row.get("stat") or {} + items.append( + HotVideoItem( + id=bvid, + platform="bilibili", + title=title, + url=f"https://www.bilibili.com/video/{bvid}", + cover_url=_normalize_image_url(row.get("pic")), + author=str(owner.get("name") or "").strip(), + rank=len(items) + 1, + hot_score=_format_bilibili_views(stat.get("view")), + source="bilibili_popular", + ) + ) + return items + + +def _map_bilibili_reader_markdown_items(markdown: str, limit: int) -> list[HotVideoItem]: + lines = markdown.splitlines() + items: list[HotVideoItem] = [] + seen: set[str] = set() + link_pattern = re.compile( + r"\[!\[[^\]]+\]\((?P[^)]*)\)\]\(" + r"(?Phttps://www\.bilibili\.com/video/(?PBV[0-9A-Za-z]+)[^)]*)\)" + ) + for index, line in enumerate(lines): + match = link_pattern.search(line) + if not match: + continue + bvid = match.group("bvid") + if bvid in seen: + continue + details = _next_nonempty_lines(lines, index + 1, count=3) + if not details: + continue + title = details[0] + if not title: + continue + seen.add(bvid) + items.append( + HotVideoItem( + id=bvid, + platform="bilibili", + title=title, + url=f"https://www.bilibili.com/video/{bvid}", + cover_url=_normalize_image_url(match.group("cover")), + author=details[1] if len(details) > 1 else "", + rank=len(items) + 1, + hot_score=_format_reader_bilibili_hot_score(details[2] if len(details) > 2 else ""), + source="bilibili_popular_reader", + ) + ) + if len(items) >= limit: + break + return items + + +def _snapshot_bilibili_hot_items(limit: int) -> list[HotVideoItem]: + return [ + HotVideoItem( + id=row["id"], + platform="bilibili", + title=row["title"], + url=f"https://www.bilibili.com/video/{row['id']}", + cover_url=row["cover_url"], + author=row["author"], + rank=index + 1, + hot_score=row["hot_score"], + source="bilibili_popular_snapshot", + ) + for index, row in enumerate(BILIBILI_POPULAR_SNAPSHOT[:limit]) + ] + + +_CacheEntry = tuple[float, dict[str, Any]] +_CACHE: dict[tuple[str, int], _CacheEntry] = {} + + +def clear_hot_video_cache() -> None: + _CACHE.clear() + + +def _session() -> requests.Session: + session = requests.Session() + session.headers.update( + { + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/125.0.0.0 Safari/537.36" + ), + "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.7", + } + ) + return session + + +def _fetch_bilibili_hot(limit: int) -> PlatformHotVideoResult: + primary_error = "" + try: + response = _session().get( + BILIBILI_POPULAR_API_URL, + params={"ps": limit, "pn": 1}, + timeout=DEFAULT_TIMEOUT_SECONDS, + ) + response.raise_for_status() + payload = response.json() + if payload.get("code") not in (0, None): + raise RuntimeError(str(payload.get("message") or "B 站热点接口返回异常")) + items = _map_bilibili_popular_items(payload, limit=limit) + if items: + return PlatformHotVideoResult( + platform="bilibili", + status="ok", + message="", + items=items, + ) + primary_error = "B 站热点暂时没有可用视频" + except Exception as exc: + primary_error = str(exc) or "B 站热点接口返回异常" + + try: + response = _session().get( + BILIBILI_POPULAR_READER_URL, + timeout=DEFAULT_TIMEOUT_SECONDS, + ) + response.raise_for_status() + items = _map_bilibili_reader_markdown_items(response.text, limit=limit) + except Exception: + items = _snapshot_bilibili_hot_items(limit) + return PlatformHotVideoResult( + platform="bilibili", + status="ok" if items else "error", + message="实时热点源暂不可用,已显示最近热门快照" if items else primary_error, + items=items, + ) + + return PlatformHotVideoResult( + platform="bilibili", + status="ok" if items else "error", + message="官方热点接口暂不可用,已切换备用热点源" if items else primary_error, + items=items, + ) + + +def _fetch_youtube_hot(limit: int) -> PlatformHotVideoResult: + html = _session().get( + "https://www.youtube.com/feed/trending", + timeout=DEFAULT_TIMEOUT_SECONDS, + ).text + items = _map_youtube_trending_html(html, limit=limit) + return PlatformHotVideoResult( + platform="youtube", + status="ok" if items else "error", + message="" if items else "YouTube 热点暂时获取失败,可稍后刷新或手动粘贴链接", + items=items, + ) + + +def _fetch_douyin_hot(limit: int) -> PlatformHotVideoResult: + try: + response = _session().get( + "https://www.douyin.com/aweme/v1/web/hot/search/list/", + params={ + "device_platform": "webapp", + "aid": "6383", + "channel": "channel_pc_web", + "detail_list": "1", + }, + timeout=DEFAULT_TIMEOUT_SECONDS, + ) + response.raise_for_status() + items = _map_douyin_hot_items(response.json(), limit=limit) + except (JSONDecodeError, ValueError) as exc: + return PlatformHotVideoResult( + platform="douyin", + status="error", + message="抖音热点接口返回了风控/空页面,暂时无法解析", + items=[], + ) + return PlatformHotVideoResult( + platform="douyin", + status="ok" if items else "error", + message="" if items else "抖音热点受风控限制,稍后刷新或手动粘贴链接", + items=items, + ) + + +def _fetch_kuaishou_hot(limit: int) -> PlatformHotVideoResult: + return PlatformHotVideoResult( + platform="kuaishou", + status="error", + message="快手热点暂时获取失败,可手动粘贴链接", + items=[], + ) + + +def _fetch_xiaohongshu_hot(limit: int) -> PlatformHotVideoResult: + return PlatformHotVideoResult( + platform="xiaohongshu", + status="unavailable", + message="小红书暂未提供稳定公开视频热点源", + items=[], + ) + + +def _fetch_newsnow_hot(platform_id: HotPlatform, limit: int) -> PlatformHotVideoResult: + import os + api_url = os.getenv("NEWSNOW_API_URL", "https://newsnow.busiyi.world/api/s") + if not api_url.endswith("/api/s"): + api_url = api_url.rstrip("/") + "/api/s" + try: + response = _session().get( + api_url, + params={"id": platform_id}, + timeout=DEFAULT_TIMEOUT_SECONDS, + ) + response.raise_for_status() + payload = response.json() + raw_items = payload.get("items", [])[:limit] + items: list[HotVideoItem] = [] + for index, item in enumerate(raw_items): + title = str(item.get("title") or "").strip() + if not title: + continue + items.append( + HotVideoItem( + id=str(item.get("id") or ""), + platform=platform_id, + title=title, + url=str(item.get("url") or ""), + cover_url="", + author=str(item.get("author") or "").strip(), + rank=index + 1, + hot_score=str(item.get("extra", {}).get("info") or "").strip(), + source="newsnow", + ) + ) + return PlatformHotVideoResult( + platform=platform_id, + status="ok", + message="", + items=items, + ) + except Exception as exc: + return PlatformHotVideoResult( + platform=platform_id, + status="error", + message=f"{platform_id} 热点获取失败: {exc}", + items=[], + ) + + + +# ── HOT_FETCHERS: direct platforms + all others via newsnow ───────────────────── + +def _make_newsnow_fetcher(platform_id: str): + """Create a fetcher for any newsnow-supported platform.""" + def _fetcher(limit: int) -> PlatformHotVideoResult: + return _fetch_newsnow_hot(platform_id, limit) # type: ignore[arg-type] + return _fetcher + +_DIRECT: dict[HotPlatform, Callable[[int], PlatformHotVideoResult]] = { + "bilibili": _fetch_bilibili_hot, + "youtube": _fetch_youtube_hot, + "douyin": _fetch_douyin_hot, + "kuaishou": _fetch_kuaishou_hot, + "xiaohongshu": _fetch_xiaohongshu_hot, +} + +HOT_FETCHERS: dict[HotPlatform, Callable[[int], PlatformHotVideoResult]] = {} +HOT_FETCHERS.update(_DIRECT) +for _pid in SUPPORTED_HOT_PLATFORMS: + if _pid not in HOT_FETCHERS: + HOT_FETCHERS[_pid] = _make_newsnow_fetcher(_pid) + + +def _normalize_limit(limit: int) -> int: + try: + value = int(limit) + except (TypeError, ValueError): + value = 12 + return max(1, min(value, 30)) + + +def _error_result(platform: str, exc: Exception) -> PlatformHotVideoResult: + return PlatformHotVideoResult( + platform=platform, # type: ignore[arg-type] + status="error", + message=str(exc) or "热点暂时获取失败", + items=[], + ) + + +def _get_fetcher(platform: str) -> Callable[[int], PlatformHotVideoResult]: + """Get a fetcher for any platform. Known platforms use dedicated fetchers, + unknown platforms try newsnow API as fallback.""" + if platform in HOT_FETCHERS: + return HOT_FETCHERS[platform] # type: ignore[index] + # Try as a custom newsnow platform ID + return _make_newsnow_fetcher(platform) + + +def fetch_hot_videos(platform: str = "all", limit: int = 12) -> list[PlatformHotVideoResult]: + safe_limit = _normalize_limit(limit) + if platform == "all": + platform_names = list(HOT_FETCHERS.keys()) + else: + platform_names = [platform] + results_by_platform: dict[str, PlatformHotVideoResult] = {} + with ThreadPoolExecutor(max_workers=len(platform_names)) as executor: + future_to_platform = { + executor.submit(_get_fetcher(name), safe_limit): name for name in platform_names + } + for future in as_completed(future_to_platform): + name = future_to_platform[future] + try: + results_by_platform[name] = future.result() + except Exception as exc: + results_by_platform[name] = _error_result(name, exc) + results: list[PlatformHotVideoResult] = [] + for name in platform_names: + try: + results.append(results_by_platform[name]) + except KeyError as exc: + results.append(_error_result(name, exc)) + return results + + +def fetch_hot_video_payload( + platform: str = "all", + limit: int = 12, + *, + force: bool = False, +) -> dict[str, Any]: + safe_limit = _normalize_limit(limit) + key = (platform, safe_limit) + now = monotonic() + if not force: + cached = _CACHE.get(key) + if cached and now - cached[0] < CACHE_TTL_SECONDS: + return cached[1] + + results = fetch_hot_videos(platform=platform, limit=safe_limit) + payload = { + "platform": platform, + "limit": safe_limit, + "generated_at": _now_iso(), + "platforms": [result.to_dict() for result in results], + } + _CACHE[key] = (now, payload) + return payload + + +def _map_youtube_trending_html(html: str, limit: int) -> list[HotVideoItem]: + raw_json = _extract_balanced_json(html, "ytInitialData") + if not raw_json: + return [] + try: + payload = json.loads(raw_json) + except json.JSONDecodeError: + return [] + + items: list[HotVideoItem] = [] + seen: set[str] = set() + for node in _walk_dicts(payload): + renderer = node.get("videoRenderer") + if not isinstance(renderer, dict): + continue + video_id = str(renderer.get("videoId") or "").strip() + title = _first_text(renderer.get("title")) + if not video_id or not title or video_id in seen: + continue + seen.add(video_id) + thumbnails = ((renderer.get("thumbnail") or {}).get("thumbnails") or []) + cover_url = "" + if thumbnails and isinstance(thumbnails[-1], dict): + cover_url = str(thumbnails[-1].get("url") or "") + items.append( + HotVideoItem( + id=video_id, + platform="youtube", + title=title, + url=f"https://www.youtube.com/watch?v={video_id}", + cover_url=cover_url, + author=_first_text(renderer.get("ownerText")), + rank=len(items) + 1, + hot_score=_first_text(renderer.get("shortViewCountText")), + source="youtube_trending", + ) + ) + if len(items) >= limit: + break + return items + + +def _map_douyin_hot_items(payload: dict[str, Any], limit: int) -> list[HotVideoItem]: + items: list[HotVideoItem] = [] + seen: set[str] = set() + for node in _walk_dicts(payload): + aweme_infos = node.get("aweme_infos") + if isinstance(aweme_infos, list): + for aweme in aweme_infos: + if not isinstance(aweme, dict): + continue + item = _douyin_item_from_node( + aweme, + rank=len(items) + 1, + parent_hot_value=node.get("hot_value"), + ) + if item and item.id not in seen: + seen.add(item.id) + items.append(item) + if len(items) >= limit: + return items + + for node in _walk_dicts(payload): + item = _douyin_item_from_node(node, rank=len(items) + 1) + if item and item.id not in seen: + seen.add(item.id) + items.append(item) + if len(items) >= limit: + break + return items + + +def _douyin_item_from_node( + node: dict[str, Any], + rank: int, + parent_hot_value: Any = None, +) -> HotVideoItem | None: + aweme_id = str(node.get("aweme_id") or node.get("group_id") or "").strip() + if not re.fullmatch(r"\d{10,}", aweme_id or ""): + return None + title = str(node.get("desc") or node.get("title") or node.get("word") or "").strip() + if not title: + return None + author = node.get("author") if isinstance(node.get("author"), dict) else {} + video = node.get("video") if isinstance(node.get("video"), dict) else {} + hot_value = parent_hot_value or node.get("hot_value") or node.get("view_count") or "" + return HotVideoItem( + id=aweme_id, + platform="douyin", + title=title, + url=f"https://www.douyin.com/video/{aweme_id}", + cover_url=_pick_cover_from_node(video), + author=str(author.get("nickname") or "").strip(), + rank=rank, + hot_score=f"{hot_value}热度" if hot_value else "", + source="douyin_hot_search", + ) + + +def _next_nonempty_lines(lines: list[str], start: int, count: int) -> list[str]: + values: list[str] = [] + for line in lines[start:]: + text = line.strip() + if not text: + continue + values.append(text) + if len(values) >= count: + break + return values + + +def _format_reader_bilibili_hot_score(text: str) -> str: + parts = [part for part in re.split(r"\s+", text.strip()) if part] + if not parts: + return "" + first = parts[0] + if "播放" in first or not re.search(r"\d", first): + return first + return f"{first}播放" + + +def _first_text(value: Any) -> str: + if isinstance(value, str): + return value.strip() + if isinstance(value, dict): + if isinstance(value.get("simpleText"), str): + return value["simpleText"].strip() + runs = value.get("runs") + if isinstance(runs, list): + return "".join( + str(run.get("text") or "") for run in runs if isinstance(run, dict) + ).strip() + return "" + + +def _walk_dicts(value: Any): + if isinstance(value, dict): + yield value + for child in value.values(): + yield from _walk_dicts(child) + elif isinstance(value, list): + for child in value: + yield from _walk_dicts(child) + + +def _extract_balanced_json(text: str, marker: str) -> str: + start = text.find(marker) + if start < 0: + return "" + brace = text.find("{", start) + if brace < 0: + return "" + depth = 0 + in_string = False + escaped = False + for index in range(brace, len(text)): + char = text[index] + if in_string: + if escaped: + escaped = False + elif char == "\\": + escaped = True + elif char == '"': + in_string = False + continue + if char == '"': + in_string = True + elif char == "{": + depth += 1 + elif char == "}": + depth -= 1 + if depth == 0: + return text[brace : index + 1] + return "" + + +def _pick_cover_from_node(node: dict[str, Any]) -> str: + candidates = [ + node.get("cover"), + node.get("origin_cover"), + node.get("dynamic_cover"), + ] + for candidate in candidates: + if isinstance(candidate, dict): + urls = candidate.get("url_list") or [] + if urls: + return str(urls[0]) + return "" diff --git a/backend/app/services/model.py b/backend/app/services/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5e796b5d0d7f97a1ad412e9d78ce49466a98db98 --- /dev/null +++ b/backend/app/services/model.py @@ -0,0 +1,219 @@ + + +from app.db.model_dao import insert_model, get_all_models, get_model_by_provider_and_name, delete_model +from app.db.provider_dao import get_enabled_providers +from app.enmus.exception import ProviderErrorEnum +from app.exceptions.provider import ProviderError +from app.gpt.gpt_factory import GPTFactory +from app.gpt.provider.OpenAI_compatible_provider import OpenAICompatibleProvider +from app.models.model_config import ModelConfig +from app.services.provider import ProviderService +from app.utils.logger import get_logger + +logger=get_logger(__name__) +class ModelService: + + @staticmethod + def _build_model_config(provider: dict) -> ModelConfig: + return ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + provider=provider["name"], + model_name='', + name=provider["name"], + ) + + @staticmethod + def get_model_list(provider_id: int, verbose: bool = False): + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + return [] + + try: + config = ModelService._build_model_config(provider) + gpt = GPTFactory().from_config(config) + models = gpt.list_models() + if verbose: + print(f"[{provider['name']}] 模型列表: {models}") + return models + except Exception as e: + print(f"[{provider['name']}] 获取模型失败: {e}") + return [] + + @staticmethod + def get_all_models(verbose: bool = False): + try: + raw_models = get_all_models() + if verbose: + print(f"所有模型列表: {raw_models}") + return ModelService._format_models(raw_models) + except Exception as e: + print(f"获取所有模型失败: {e}") + return [] + @staticmethod + def get_all_models_safe(verbose: bool = False): + try: + raw_models = get_all_models() + if verbose: + print(f"所有模型列表: {raw_models}") + return ModelService._format_models(raw_models) + except Exception as e: + print(f"获取所有模型失败: {e}") + return [] + @staticmethod + def _format_models(raw_models: list) -> list: + """ + 格式化模型列表 + """ + formatted = [] + for model in raw_models: + formatted.append({ + "id": model.get("id"), + "provider_id": model.get("provider_id"), + "model_name": model.get("model_name"), + "created_at": model.get("created_at", None), # 如果有created_at字段 + }) + return formatted + + @staticmethod + def _extract_remote_models(raw_models) -> list: + if raw_models is None: + return [] + if isinstance(raw_models, dict): + raw_models = raw_models.get("data", raw_models.get("models", raw_models)) + elif hasattr(raw_models, "data"): + raw_models = raw_models.data + + if isinstance(raw_models, list): + return raw_models + return [] + + @staticmethod + def _serialize_remote_model(model) -> dict: + if isinstance(model, dict): + return model + if hasattr(model, "model_dump"): + return model.model_dump() + if hasattr(model, "dict"): + return model.dict() + + model_id = getattr(model, "id", None) + if model_id: + return { + "id": model_id, + "object": getattr(model, "object", "model"), + "created": getattr(model, "created", None), + "owned_by": getattr(model, "owned_by", None), + } + return {} + + @staticmethod + def get_enabled_models_by_provider( provider_id: str|int,): + from app.db.model_dao import get_models_by_provider + + all_models = get_models_by_provider(provider_id) + enabled_models = all_models + return enabled_models + @staticmethod + def get_all_models_by_id(provider_id: str, verbose: bool = False): + try: + provider = ProviderService.get_provider_by_id(provider_id) + + models = ModelService.get_model_list(provider["id"], verbose=verbose) + remote_models = ModelService._extract_remote_models(models) + serializable_models = [ + item + for item in (ModelService._serialize_remote_model(model) for model in remote_models) + if item.get("id") + ] + model_list = { + "models": serializable_models + } + + logger.info(f"[{provider['name']}] 获取模型成功") + return model_list + except Exception as e: + # print(f"[{provider_id}] 获取模型失败: {e}") + logger.error(f"[{provider_id}] 获取模型失败: {e}") + return [] + @staticmethod + def connect_test(id: str, model: str | None = None) -> bool: + """连通性测试:发一条最小化 chat completion。 + + model 优先级: + 1. 调用方显式传入(前端可在「模型选择」UI 里挑一个再测) + 2. DB 中该 provider 已保存的第一个模型 + 3. 都没有 → 抛错让用户先加一个模型 + """ + provider = ProviderService.get_provider_by_id(id) + if not provider: + raise ProviderError( + code=ProviderErrorEnum.NOT_FOUND.code, + message=ProviderErrorEnum.NOT_FOUND.message, + ) + if not provider.get('api_key'): + raise ProviderError( + code=ProviderErrorEnum.NOT_FOUND.code, + message=ProviderErrorEnum.NOT_FOUND.message, + ) + + if not model: + saved_models = ModelService.get_enabled_models_by_provider(provider["id"]) + if not saved_models: + raise ProviderError( + code=ProviderErrorEnum.WRONG_PARAMETER.code, + message="请先为该供应商添加至少一个模型再测试连通性", + ) + model = saved_models[0]["model_name"] + + ok = OpenAICompatibleProvider.test_connection( + api_key=provider.get('api_key'), + base_url=provider.get('base_url'), + model=model, + ) + if ok: + return True + raise ProviderError( + code=ProviderErrorEnum.WRONG_PARAMETER.code, + message=ProviderErrorEnum.WRONG_PARAMETER.message, + ) + + + + @staticmethod + def delete_model_by_id( model_id: int) -> bool: + try: + delete_model(model_id) + return True + except Exception as e: + print(f"[{model_id}] : {e}") + return False + @staticmethod + def add_new_model(provider_id: int, model_name: str) -> bool: + try: + # 先查供应商是否存在 + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + print(f"供应商ID {provider_id} 不存在,无法添加模型") + return False + + # 查询是否已存在同名模型 + existing = get_model_by_provider_and_name(provider_id, model_name) + if existing: + print(f"模型 {model_name} 已存在于供应商ID {provider_id} 下,跳过插入") + return False + + # 插入模型 + insert_model(provider_id=provider_id, model_name=model_name) + print(f"模型 {model_name} 已成功添加到供应商ID {provider_id}") + return True + except Exception as e: + print(f"添加模型失败: {e}") + return False + +if __name__ == '__main__': + # 单个 Provider 测试 + print(ModelService.get_model_list(1, verbose=True)) + + # 所有 Provider 模型测试 + # print(ModelService.get_all_models(verbose=True)) diff --git a/backend/app/services/note.py b/backend/app/services/note.py new file mode 100644 index 0000000000000000000000000000000000000000..f34c95b1ff1371d77a6b971cb70579eca3b89346 --- /dev/null +++ b/backend/app/services/note.py @@ -0,0 +1,1096 @@ +import json +import logging +import os +import time +from dataclasses import asdict +from pathlib import Path +from typing import List, Optional, Tuple, Union, Any + +from fastapi import HTTPException +from pydantic import HttpUrl +from dotenv import load_dotenv + +from app.downloaders.base import Downloader +from app.downloaders.bilibili_downloader import BilibiliDownloader +from app.downloaders.douyin_downloader import DouyinDownloader +from app.downloaders.local_downloader import LocalDownloader +from app.downloaders.youtube_downloader import YoutubeDownloader +from app.db.video_task_dao import delete_task_by_video, insert_video_task +from app.enmus.exception import NoteErrorEnum, ProviderErrorEnum +from app.enmus.task_status_enums import TaskStatus +from app.enmus.note_enums import DownloadQuality +from app.exceptions.note import NoteError +from app.exceptions.provider import ProviderError +from app.gpt.base import GPT +from app.gpt.gpt_factory import GPTFactory +from app.models.audio_model import AudioDownloadResult +from app.models.gpt_model import GPTSource +from app.models.model_config import ModelConfig +from app.models.notes_model import AudioDownloadResult, NoteResult +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.services.constant import SUPPORT_PLATFORM_MAP +from app.services.provider import ProviderService +from app.transcriber.base import Transcriber +from app.transcriber.transcriber_provider import get_transcriber, _transcribers +from app.utils.cover_helper import localize_cover +from app.utils.note_helper import replace_content_markers, prepend_source_link, normalize_toc +from app.utils.path_helper import get_runtime_dir +from app.utils.screenshot_marker import ( + ensure_screenshot_markers, + extract_content_timestamps, + extract_screenshot_timestamps, +) +from app.utils.status_code import StatusCode +from app.utils.video_helper import generate_screenshot +from app.utils.video_reader import VideoReader + +# ------------------ 环境变量与全局配置 ------------------ + +# 从 .env 文件中加载环境变量 +load_dotenv() + +# 后端 API 地址与端口(若有需要可以在代码其他部分使用 BACKEND_BASE_URL) +API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost") +BACKEND_PORT = os.getenv("BACKEND_PORT", "8483") +BACKEND_BASE_URL = f"{API_BASE_URL}:{BACKEND_PORT}" + +# 输出目录(用于缓存音频、转写、Markdown 文件,以及存储截图) +NOTE_OUTPUT_DIR = Path(os.getenv("NOTE_OUTPUT_DIR", "note_results")) +NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) +# 截图目录必须落在 /static 挂载目录之下(见 main.py 的 static_dir)。仅当 OUT_DIR 显式给出 +# 绝对路径时才采用它;相对值(如默认 .env 里的 ./static/screenshots)一律改用 get_runtime_dir +# 推导,确保打包后与挂载目录同源,避免 cwd 漂移导致截图 404。 +_env_out_dir = os.getenv("OUT_DIR") +IMAGE_OUTPUT_DIR = ( + _env_out_dir + if _env_out_dir and os.path.isabs(_env_out_dir) + else os.path.join(get_runtime_dir("static"), "screenshots") +) +os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True) +# 图片基础 URL(用于生成 Markdown 中的图片链接,需前端静态目录对应) +IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL", "/static/screenshots") + +# 日志配置 +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# ─── 下载失败的友好提示(Cookie 相关) ─────────────────────────────── +# 平台中文名(用于错误提示) +_PLATFORM_LABELS = { + "bilibili": "B站", + "youtube": "YouTube", + "douyin": "抖音", + "kuaishou": "快手", + "xiaohongshu": "小红书", +} + +# 这些平台没配 Cookie 时下载/解析大概率失败。抖音现在走移动端分享页 SSR, +# 不再依赖 Cookie 签名接口,解析失败时不要误导用户去补 Cookie。 +_COOKIE_RECOMMENDED_PLATFORMS = {"kuaishou", "xiaohongshu", "bilibili"} + +# 报错信息中出现这些特征词时,多半与登录态 / Cookie 有关 +_COOKIE_ERROR_SIGNALS = ( + "cookie", "Cookie", "COOKIE", + "登录", "登陆", "login", "Login", "Sign in", "sign in", + "大会员", "Premium", "members-only", "member only", + "412", "403", "Forbidden", "风控", +) + +# 已翻译过的提示都带这个标记,避免二次翻译 +_COOKIE_HINT_MARK = "「设置 → 下载配置」" + + +def friendly_download_error(exc: Exception, platform: str) -> str: + """把下载/解析阶段的报错翻译成用户能行动的提示。 + + - 平台未设置 Cookie,且(平台强依赖 Cookie 或报错带登录特征)→ 提示去配置 Cookie; + - 已设置 Cookie 但报错带登录特征 → 提示 Cookie 可能已失效; + - 其余情况原样返回,并保留原始错误便于排查。 + """ + raw = str(getattr(exc, "message", None) or exc) + if _COOKIE_HINT_MARK in raw: # 幂等:已翻译过的不再处理 + return raw + + try: + from app.services.cookie_manager import CookieConfigManager + cfm = CookieConfigManager() + configured = bool(cfm.get(platform) or cfm.get_browser(platform)) + except Exception: + configured = True # 读不到配置时宁可不提示,也不要误报「未设置」 + + label = _PLATFORM_LABELS.get(platform, platform or "该平台") + has_signal = any(s in raw for s in _COOKIE_ERROR_SIGNALS) + + if not configured and (platform in _COOKIE_RECOMMENDED_PLATFORMS or has_signal): + return ( + f"{label}下载/解析失败,可能是未设置 Cookie 导致:请到「设置 → 下载配置」" + f"为{label}配置 Cookie 后重试。(原始错误:{raw[:300]})" + ) + if configured and has_signal: + return ( + f"{label}的 Cookie 可能已失效或权限不足:请到「设置 → 下载配置」" + f"更新{label}的 Cookie 后重试。(原始错误:{raw[:300]})" + ) + return raw + + +class NoteGenerator: + """ + NoteGenerator 用于执行视频/音频下载、转写、GPT 生成笔记、插入截图/链接、 + 以及将任务信息写入状态文件与数据库等功能。 + """ + + def __init__(self): + from app.services.transcriber_config_manager import TranscriberConfigManager + config_manager = TranscriberConfigManager() + cfg = config_manager.get_config() + self.model_size: str = cfg["whisper_model_size"] + self.device: Optional[str] = None + self.transcriber_type: str = cfg["transcriber_type"] + self.funasr_model: str = cfg.get("funasr_model") or "paraformer-zh" + self._transcriber: Optional[Transcriber] = None + self.video_path: Optional[Path] = None + self.video_img_urls=[] + logger.info("NoteGenerator 初始化完成") + + @property + def transcriber(self) -> Transcriber: + """懒加载转写器:仅在真正需要转写时才初始化。 + + NoteGenerator 还被用于写任务状态、删除笔记、润色等轻量操作; + 之前在 __init__ 里 eager 初始化,配置了不可用引擎(如 mlx-whisper + 未安装)时,连 /generate_note 写 PENDING 状态都会直接 500。 + """ + if self._transcriber is None: + self._transcriber = self._init_transcriber() + return self._transcriber + + + # ---------------- 公有方法 ---------------- + + def generate( + self, + video_url: Union[str, HttpUrl], + platform: str, + quality: DownloadQuality = DownloadQuality.medium, + task_id: Optional[str] = None, + model_name: Optional[str] = None, + provider_id: Optional[str] = None, + link: bool = False, + screenshot: bool = False, + _format: Optional[List[str]] = None, + style: Optional[str] = None, + extras: Optional[str] = None, + output_path: Optional[str] = None, + video_understanding: bool = False, + video_interval: int = 0, + grid_size: Optional[List[int]] = None, + ) -> NoteResult | None: + """ + 主流程:按步骤依次下载、转写、GPT 总结、截图/链接处理、存库、返回 NoteResult。 + + :param video_url: 视频或音频链接 + :param platform: 平台名称,对应 SUPPORT_PLATFORM_MAP 中的键 + :param quality: 下载音频的质量枚举 + :param task_id: 用于标识本次任务的唯一 ID,亦用于状态文件和缓存文件命名 + :param model_name: GPT 模型名称 + :param provider_id: 模型供应商 ID + :param link: 是否在笔记中插入视频片段链接 + :param screenshot: 是否在笔记中替换 Screenshot 标记为图片 + :param _format: 包含 'link' 或 'screenshot' 等字符串的列表,决定后续处理 + :param style: GPT 生成笔记的风格 + :param extras: 额外参数,传递给 GPT + :param output_path: 下载输出目录(可选) + :param video_understanding: 是否需要视频拼图理解(生成缩略图) + :param video_interval: 视频帧截取间隔(秒),仅在 video_understanding 为 True 时生效 + :param grid_size: 生成缩略图时的网格大小,如 [3, 3] + :return: NoteResult 对象,包含 markdown 文本、转写结果和音频元信息 + """ + if grid_size is None: + grid_size = [] + + try: + logger.info(f"开始生成笔记 (task_id={task_id})") + self._update_status(task_id, TaskStatus.PARSING) + + # 获取下载器与 GPT 实例 + + downloader = self._get_downloader(platform) + gpt = self._get_gpt(model_name, provider_id) + + # 缓存文件路径 + audio_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_audio.json" + transcript_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_transcript.json" + markdown_cache_file = NOTE_OUTPUT_DIR / f"{task_id}_markdown.md" + video_transcript_cache_file = None + + # 1. 获取字幕/转写:优先缓存 → 平台字幕 → 音频转写 + transcript = None + + # 尝试读取缓存 + if transcript_cache_file.exists(): + logger.info(f"检测到转写缓存 ({transcript_cache_file}),尝试读取") + try: + data = json.loads(transcript_cache_file.read_text(encoding="utf-8")) + segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])] + transcript = TranscriptResult( + language=data.get("language"), + full_text=data["full_text"], + segments=segments, + ) + logger.info(f"已从缓存加载转写结果,共 {len(segments)} 段") + except Exception as e: + logger.warning(f"加载转写缓存失败: {e}") + + # 缓存没有,尝试获取平台字幕 + if transcript is None: + logger.info("尝试获取平台字幕(优先于音频下载)...") + try: + transcript = downloader.download_subtitles(video_url) + if transcript and transcript.segments: + logger.info(f"成功获取平台字幕,共 {len(transcript.segments)} 段") + transcript_cache_file.write_text( + json.dumps(asdict(transcript), ensure_ascii=False, indent=2), + encoding="utf-8", + ) + else: + transcript = None + logger.info("平台无可用字幕,将下载音频后转写") + except Exception as e: + logger.warning(f"获取平台字幕失败: {e},将下载音频后转写") + transcript = None + + # 暂停门(步骤1→2):解析完成后、下载前可暂停 + self._gate(task_id, TaskStatus.PARSING) + + # 2. 下载音频/视频 + # 有字幕时只提取元信息,不下载音视频文件(除非需要截图/视频理解) + has_transcript = transcript is not None + need_full_download = not has_transcript or screenshot or video_understanding + audio_meta = self._download_media( + downloader=downloader, + video_url=video_url, + quality=quality, + audio_cache_file=audio_cache_file, + status_phase=TaskStatus.DOWNLOADING, + platform=platform, + output_path=output_path, + screenshot=screenshot, + video_understanding=video_understanding, + video_interval=video_interval, + grid_size=grid_size, + skip_download=not need_full_download, + ) + + # 封面本地化:B 站封面是 http 直链(桌面端 WebView 按 mixed content 拦截)、 + # 抖音/快手是限时签名 URL(过期 404)。下载到 /static/covers 后存稳定相对路径, + # 失败则保留原始 URL,由前端直链 + 代理兜底。 + if audio_meta.cover_url and str(audio_meta.cover_url).startswith("http"): + local_cover = localize_cover(audio_meta.cover_url, platform) + if local_cover: + audio_meta.cover_url = local_cover + + video_transcript_cache_file = self._video_transcript_cache_path(platform, audio_meta.video_id) + if transcript is not None: + self._write_transcript_cache( + transcript=transcript, + target=video_transcript_cache_file, + log_label="视频级平台字幕缓存", + ) + + # 暂停门(步骤2→3):下载完成后、转写前可暂停 + self._gate(task_id, TaskStatus.DOWNLOADING) + + # 3. 如果前面没拿到字幕,走转写流程 + if transcript is None: + transcript = self._get_transcript( + downloader=downloader, + video_url=video_url, + audio_file=audio_meta.file_path, + transcript_cache_file=transcript_cache_file, + status_phase=TaskStatus.TRANSCRIBING, + task_id=task_id, + # 视频级缓存:同一视频重复生成笔记时复用音频转写结果 + video_cache_file=video_transcript_cache_file, + ) + else: + # 字幕路径:已直接拿到转写文本,无需音频转写。仍显式标记「转写文字」步骤, + # 否则进度会从「下载」直接跳到「总结」,看起来第3、4步一起完成。 + self._update_status(task_id, TaskStatus.TRANSCRIBING, cache="platform_subtitle") + + # 暂停门(步骤3→4):转写完成后、总结前可暂停。 + # 注意:进入总结(第4步)后到第5步之间不再设暂停门——前端会禁用暂停按钮。 + self._gate(task_id, TaskStatus.TRANSCRIBING) + + # 3. GPT 总结 + markdown = self._summarize_text( + audio_meta=audio_meta, + transcript=transcript, + gpt=gpt, + markdown_cache_file=markdown_cache_file, + link=link, + screenshot=screenshot, + formats=_format or [], + style=style, + extras=extras, + video_img_urls=self.video_img_urls, + ) + + # 4. 截图 & 链接替换 + if _format: + markdown = self._post_process_markdown( + markdown=markdown, + video_path=self.video_path, + formats=_format, + audio_meta=audio_meta, + platform=platform, + ) + + # 目录区块确定性整形:LLM 偶尔把 ## 标记抄进目录条目 / 生成嵌套子项 + markdown = normalize_toc(markdown) + markdown = prepend_source_link(markdown, str(video_url)) + + # 5. 保存记录到数据库 + self._update_status(task_id, TaskStatus.SAVING) + self._save_metadata(video_id=audio_meta.video_id, platform=platform, task_id=task_id) + + # 6. 完成 + from app.services import task_control + task_control.clear(task_id) + total_tokens = int(getattr(gpt, "total_tokens", 0) or 0) + self._update_status(task_id, TaskStatus.SUCCESS) + logger.info(f"笔记生成成功 (task_id={task_id}),消耗 token:{total_tokens}") + + # 7. 异步建立向量索引:跨笔记问答需要 chunks,生成成功就顺手把它索引一份, + # 用户不用再去 Knowledge 页面手动点「重建索引」。 + # 失败不影响生成主流程,仅记录日志;用线程后台跑,避免阻塞响应返回。 + try: + import threading + from app.services.vector_store import VectorStoreManager + + def _auto_index(tid: str): + try: + VectorStoreManager().index_task(tid) + logger.info(f"自动建立索引完成 (task_id={tid})") + except Exception as ie: + logger.warning(f"自动建立索引失败 (task_id={tid}):{ie}") + + threading.Thread(target=_auto_index, args=(task_id,), daemon=True).start() + except Exception as ie: + logger.warning(f"调度自动索引失败 (task_id={task_id}):{ie}") + + return NoteResult( + markdown=markdown, + transcript=transcript, + audio_meta=audio_meta, + total_tokens=total_tokens, + ) + + except Exception as exc: + from app.services import task_control + task_control.clear(task_id) + logger.error(f"生成笔记流程异常 (task_id={task_id}):{exc}", exc_info=True) + self._update_status(task_id, TaskStatus.FAILED, message=str(exc)) + return None + + @staticmethod + def delete_note(video_id: str, platform: str) -> int: + """ + 删除数据库中对应 video_id 与 platform 的任务记录 + + :param video_id: 视频 ID + :param platform: 平台标识 + :return: 删除的记录数 + """ + logger.info(f"删除笔记记录 (video_id={video_id}, platform={platform})") + return delete_task_by_video(video_id, platform) + + # ---------------- 私有方法 ---------------- + + def _init_transcriber(self) -> Transcriber: + """ + 根据环境变量 TRANSCRIBER_TYPE 动态获取并实例化转写器 + """ + if self.transcriber_type not in _transcribers: + logger.error(f"未找到支持的转写器:{self.transcriber_type}") + raise Exception(f"不支持的转写器:{self.transcriber_type}") + + logger.info(f"使用转写器:{self.transcriber_type} (model_size={self.model_size})") + # 必须显式传 model_size:不传的话 get_transcriber 会落到环境变量/默认值, + # 「音频转写配置」页选的模型大小就不生效了 + return get_transcriber(transcriber_type=self.transcriber_type, model_size=self.model_size) + + def _get_gpt(self, model_name: Optional[str], provider_id: Optional[str]) -> GPT: + """ + 根据 provider_id 获取对应的 GPT 实例 + :param model_name: GPT 模型名称 + :param provider_id: 供应商 ID + :return: GPT 实例 + """ + provider = ProviderService.get_provider_by_id(provider_id) + if not provider: + logger.error(f"[get_gpt] 未找到模型供应商: provider_id={provider_id}") + raise ProviderError(code=ProviderErrorEnum.NOT_FOUND,message=ProviderErrorEnum.NOT_FOUND.message) + logger.info(f"创建 GPT 实例 {provider_id}") + config = ModelConfig( + api_key=provider["api_key"], + base_url=provider["base_url"], + model_name=model_name, + provider=provider["type"], + name=provider["name"], + ) + return GPTFactory().from_config(config) + + def _get_downloader(self, platform: str) -> Downloader: + """ + 根据平台名称获取对应的下载器实例 + + :param platform: 平台标识,需在 SUPPORT_PLATFORM_MAP 中 + :return: 对应的 Downloader 子类实例 + """ + downloader_cls = SUPPORT_PLATFORM_MAP.get(platform) + logger.debug(f"实例化下载器 - {platform}") + instance = None + if not downloader_cls: + # 兜底:查用户在「下载配置」里登记的自定义平台 + from app.services import custom_platform_manager + from app.downloaders.generic_downloader import GenericYtdlpDownloader + custom = custom_platform_manager.get(platform) + if custom: + logger.info(f"使用自定义平台下载器: {custom['name']} (key={platform})") + return GenericYtdlpDownloader(platform=platform) + logger.error(f"不支持的平台:{platform}") + raise NoteError(code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code, + message=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.message) + try: + instance = downloader_cls + except Exception as e: + logger.error(f"实例化下载器失败:{e}") + + + logger.info(f"使用下载器:{downloader_cls.__class__}") + return instance + + def _gate(self, task_id: Optional[str], current_status: TaskStatus) -> None: + """步骤之间的暂停门:若任务被暂停,则停在当前步骤等待,直到被恢复。 + + 暂停期间保留 current_status(不推进到下一步),仅把状态文件标记 paused=true。 + 恢复后清除 paused 标记并返回,让调用方继续执行下一步。 + """ + if not task_id: + return + from app.services import task_control + + paused_written = False + while task_control.is_paused(task_id): + if not paused_written: + self._update_status(task_id, current_status, paused=True) + paused_written = True + time.sleep(1) + + if paused_written: + self._update_status(task_id, current_status, paused=False) + + def _update_status( + self, + task_id: Optional[str], + status: Union[str, TaskStatus], + message: Optional[str] = None, + paused: bool = False, + cache: Optional[str] = None, + ): + """ + 创建或更新 {task_id}.status.json,记录当前任务状态 + + :param task_id: 任务唯一 ID + :param status: TaskStatus 枚举或自定义状态字符串 + :param message: 可选消息,用于记录失败原因等 + :param paused: 是否处于暂停态(保留当前步骤,仅标记暂停) + :param cache: 可选缓存命中标记,如 transcript/platform_subtitle + """ + if not task_id: + return + + NOTE_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + status_file = NOTE_OUTPUT_DIR / f"{task_id}.status.json" + print(f"写入状态文件: {status_file} 当前状态: {status} paused={paused} cache={cache}") + data = {"status": status.value if isinstance(status, TaskStatus) else status, "paused": paused} + if message: + data["message"] = message + if cache: + data["cache"] = cache + + try: + # First create a temporary file + temp_file = status_file.with_suffix('.tmp') + + # Write to temporary file + with temp_file.open('w', encoding='utf-8') as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + # Atomic rename operation + temp_file.replace(status_file) + + print(f"状态文件写入成功: {status_file}") + except Exception as e: + logger.error(f"写入状态文件失败 (task_id={task_id}):{e}") + # Try to write error to file directly as fallback + try: + with status_file.open('w', encoding='utf-8') as f: + f.write(f"Error writing status: {str(e)}") + except: + logger.error(f"写入错误 {e}") + + def _handle_exception(self, task_id, exc): + logger.error(f"任务异常 (task_id={task_id})", exc_info=True) + error_message = getattr(exc, 'detail', str(exc)) + if isinstance(error_message, dict): + try: + error_message = json.dumps(error_message, ensure_ascii=False) + except: + error_message = str(error_message) + self._update_status(task_id, TaskStatus.FAILED, message=error_message) + + def _download_media( + self, + downloader: Downloader, + video_url: Union[str, HttpUrl], + quality: DownloadQuality, + audio_cache_file: Path, + status_phase: TaskStatus, + platform: str, + output_path: Optional[str], + screenshot: bool, + video_understanding: bool, + video_interval: int, + grid_size: List[int], + skip_download: bool = False, + ) -> AudioDownloadResult | None: + """ + 1. 检查音频缓存;若不存在,则根据需要下载音频或视频(若需截图/可视化)。 + 2. 如果需要视频,则先下载视频并生成缩略图集,再下载音频。 + 3. 返回 AudioDownloadResult + + :param downloader: Downloader 实例 + :param video_url: 视频/音频链接 + :param quality: 音频下载质量 + :param audio_cache_file: 本地缓存 JSON 文件路径 + :param status_phase: 对应的状态枚举,如 TaskStatus.DOWNLOADING + :param platform: 平台标识 + :param output_path: 下载输出目录(可为 None) + :param screenshot: 是否需要在笔记中插入截图 + :param video_understanding: 是否需要生成缩略图 + :param video_interval: 视频截帧间隔 + :param grid_size: 缩略图网格尺寸 + :return: AudioDownloadResult 对象 + """ + task_id = audio_cache_file.stem.split("_")[0] + self._update_status(task_id, status_phase) + + # 已有缓存,尝试加载 + if audio_cache_file.exists(): + logger.info(f"检测到音频缓存 ({audio_cache_file}),直接读取") + try: + data = json.loads(audio_cache_file.read_text(encoding="utf-8")) + audio = AudioDownloadResult(**data) + + # 判断是否需要下载/恢复视频 + need_video = screenshot or video_understanding + if need_video: + video_path_str = data.get("video_path") + # 尝试推导视频路径 + if not video_path_str: + # 兜底1: 针对小红书/抖音等平台,音频路径 file_path 实际上就是下载的 .mp4 视频 + if audio.file_path and audio.file_path.endswith(".mp4"): + video_path_str = audio.file_path + # 兜底2: 猜测与音频同目录下的 .mp4 + elif audio.file_path: + audio_dir = Path(audio.file_path).parent + possible_video = audio_dir / f"{audio.video_id}.mp4" + if possible_video.exists(): + video_path_str = str(possible_video) + + # 检查视频文件是否存在,不存在则重新下载视频 + if video_path_str and Path(video_path_str).exists(): + self.video_path = Path(video_path_str) + logger.info(f"已从缓存恢复视频路径:{self.video_path}") + else: + logger.info("缓存中未找到视频路径或视频文件不存在,重新下载视频") + video_path_str = downloader.download_video(video_url) + self.video_path = Path(video_path_str) + logger.info(f"重新下载视频完成:{self.video_path}") + + # 更新缓存中的 video_path + audio.video_path = str(self.video_path) + try: + audio_cache_file.write_text( + json.dumps(asdict(audio), ensure_ascii=False, indent=2), + encoding="utf-8" + ) + except Exception as cache_err: + logger.warning(f"更新音频缓存中的 video_path 失败: {cache_err}") + + # 如果需要视频理解且指定了 grid_size,在缓存命中时也重新生成/加载 self.video_img_urls + if grid_size: + frame_interval = video_interval if video_interval and video_interval > 0 else 6 + self.video_img_urls = VideoReader( + video_path=str(self.video_path), + grid_size=tuple(grid_size), + frame_interval=frame_interval, + unit_width=960, + unit_height=540, + save_quality=80, + ).run() + + return audio + except Exception as e: + logger.warning(f"读取音频缓存失败,将重新下载:{e}") + + # 有字幕且不需要截图/视频理解时,只提取元信息不下载文件 + if skip_download: + logger.info("已有字幕,仅提取视频元信息(不下载音视频)") + try: + audio = downloader.download( + video_url=video_url, + quality=quality, + output_dir=output_path, + need_video=False, + skip_download=True, + ) + audio_cache_file.write_text( + json.dumps(asdict(audio), ensure_ascii=False, indent=2), + encoding="utf-8", + ) + logger.info(f"元信息提取完成 ({audio_cache_file})") + return audio + except Exception as exc: + logger.warning(f"元信息提取失败,将尝试完整下载: {exc}") + + # 判断是否需要下载视频 + need_video = screenshot or video_understanding + if screenshot and not grid_size: + grid_size = [2, 2] + + frame_interval = video_interval if video_interval and video_interval > 0 else 6 + if need_video: + try: + logger.info("开始下载视频") + video_path_str = downloader.download_video(video_url) + self.video_path = Path(video_path_str) + logger.info(f"视频下载完成:{self.video_path}") + + if grid_size: + self.video_img_urls = VideoReader( + video_path=str(self.video_path), + grid_size=tuple(grid_size), + frame_interval=frame_interval, + unit_width=960, + unit_height=540, + save_quality=80, + ).run() + else: + logger.info("未指定 grid_size,跳过缩略图生成") + except Exception as exc: + logger.error(f"视频下载失败:{exc}") + friendly = friendly_download_error(exc, platform) + self._handle_exception(task_id, RuntimeError(friendly)) + raise RuntimeError(friendly) from exc + + # 下载音频 + try: + logger.info("开始下载音频") + audio = downloader.download( + video_url=video_url, + quality=quality, + output_dir=output_path, + need_video=need_video, + ) + if self.video_path: + audio.video_path = str(self.video_path) + audio_cache_file.write_text(json.dumps(asdict(audio), ensure_ascii=False, indent=2), encoding="utf-8") + logger.info(f"音频下载并缓存成功 ({audio_cache_file})") + return audio + except Exception as exc: + logger.error(f"音频下载失败:{exc}") + friendly = friendly_download_error(exc, platform) + self._handle_exception(task_id, RuntimeError(friendly)) + raise RuntimeError(friendly) from exc + + + def _video_transcript_cache_path(self, platform: str, video_id: Optional[str]) -> Optional[Path]: + """视频级转写缓存路径:以 平台+视频ID+转写引擎+模型 为 key。 + + 转写缓存原本只按 task_id 存,同一个视频每生成一次笔记就要完整重转一遍 + (本地 whisper 一次要数分钟)。音频转写结果在视频维度复用; + key 编入引擎与模型大小,切换转写配置后不会命中旧引擎的结果。 + """ + if not video_id: + return None + # 模型维度按引擎取:whisper 系是档位/自定义,funasr 是其模型名—— + # 否则换 FunASR 模型不会失效旧缓存(曾出现 en 模型的空结果被 zh 复用) + model_key = self.funasr_model if self.transcriber_type == "funasr" else self.model_size + raw_key = f"{platform}_{video_id}_{self.transcriber_type}_{model_key}" + safe_key = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in raw_key) + cache_dir = NOTE_OUTPUT_DIR / "video_transcripts" + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir / f"{safe_key}.json" + + def _write_transcript_cache( + self, + transcript: Optional[TranscriptResult], + target: Optional[Path], + log_label: str, + ) -> None: + """把有效转写结果写入缓存。缓存失败不影响主流程。""" + if not transcript or not target: + return + if not (transcript.segments or (transcript.full_text or "").strip()): + return + try: + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(json.dumps(asdict(transcript), ensure_ascii=False, indent=2), encoding="utf-8") + logger.info(f"{log_label}写入成功 ({target})") + except Exception as e: + logger.warning(f"{log_label}写入失败(忽略):{e}") + + def _get_transcript( + self, + downloader: Downloader, + video_url: str, + audio_file: str, + transcript_cache_file: Path, + status_phase: TaskStatus, + task_id: Optional[str] = None, + video_cache_file: Optional[Path] = None, + ) -> TranscriptResult | None: + """ + 优先获取平台字幕,没有则 fallback 到音频转写 + + :param downloader: 下载器实例 + :param video_url: 视频链接 + :param audio_file: 音频文件路径(用于 fallback 转写) + :param transcript_cache_file: 缓存文件路径 + :param status_phase: 状态枚举 + :param task_id: 任务 ID + :param video_cache_file: 视频级转写缓存路径(跨任务复用音频转写结果) + :return: TranscriptResult 对象 + """ + self._update_status(task_id, status_phase) + + # 已有缓存,直接返回 + if transcript_cache_file.exists(): + logger.info(f"检测到转写缓存 ({transcript_cache_file}),尝试读取") + try: + data = json.loads(transcript_cache_file.read_text(encoding="utf-8")) + segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])] + return TranscriptResult(language=data.get("language"), full_text=data["full_text"], segments=segments) + except Exception as e: + logger.warning(f"加载转写缓存失败,将重新获取:{e}") + + # 1. 先尝试获取平台字幕 + logger.info("尝试获取平台字幕...") + try: + transcript = downloader.download_subtitles(video_url) + if transcript and transcript.segments: + logger.info(f"成功获取平台字幕,共 {len(transcript.segments)} 段") + # 缓存结果 + payload = json.dumps(asdict(transcript), ensure_ascii=False, indent=2) + transcript_cache_file.write_text(payload, encoding="utf-8") + if video_cache_file: + self._write_transcript_cache(transcript, video_cache_file, "视频级平台字幕缓存") + return transcript + else: + logger.info("平台无可用字幕,将使用音频转写") + except Exception as e: + logger.warning(f"获取平台字幕失败: {e},将使用音频转写") + + # 2. Fallback 到音频转写 + return self._transcribe_audio( + audio_file=audio_file, + transcript_cache_file=transcript_cache_file, + status_phase=status_phase, + video_cache_file=video_cache_file, + ) + + def _transcribe_audio( + self, + audio_file: str, + transcript_cache_file: Path, + status_phase: TaskStatus, + video_cache_file: Optional[Path] = None, + ) -> TranscriptResult | None: + """ + 1. 检查转写缓存(先按 task_id,再按视频级缓存);若存在则尝试加载, + 否则调用转写器生成并缓存。 + 2. 返回 TranscriptResult 对象 + + :param audio_file: 音频文件本地路径 + :param transcript_cache_file: 转写结果缓存路径(按 task_id) + :param status_phase: 对应的状态枚举,如 TaskStatus.TRANSCRIBING + :param video_cache_file: 视频级转写缓存路径(跨任务复用,可空) + :return: TranscriptResult 对象 + """ + task_id = transcript_cache_file.stem.split("_")[0] + self._update_status(task_id, status_phase) + + # 已有缓存,尝试加载 + if transcript_cache_file.exists(): + logger.info(f"检测到转写缓存 ({transcript_cache_file}),尝试读取") + try: + data = json.loads(transcript_cache_file.read_text(encoding="utf-8")) + segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])] + return TranscriptResult(language=data["language"], full_text=data["full_text"], segments=segments) + except Exception as e: + logger.warning(f"加载转写缓存失败,将重新转写:{e}") + + # 视频级缓存:同一视频之前的任务已经转写过,直接复用 + if video_cache_file and video_cache_file.exists(): + logger.info(f"检测到视频级转写缓存 ({video_cache_file}),尝试读取") + try: + raw_text = video_cache_file.read_text(encoding="utf-8") + data = json.loads(raw_text) + segments = [TranscriptSegment(**seg) for seg in data.get("segments", [])] + transcript = TranscriptResult( + language=data.get("language"), + full_text=data["full_text"], + segments=segments, + ) + # 回写本任务的缓存文件,repolish 等按 task_id 读取的流程不受影响 + transcript_cache_file.write_text(raw_text, encoding="utf-8") + self._update_status(task_id, status_phase, cache="transcript") + return transcript + except Exception as e: + logger.warning(f"加载视频级转写缓存失败,将重新转写:{e}") + + # 调用转写器 + try: + logger.info("开始转写音频") + transcript = self.transcriber.transcript(file_path=audio_file) + # 空转写视为失败:空结果一旦入缓存,重试会一直命中空数据并在 GPT 总结阶段 + # 以难懂的 IndexError 崩掉(曾因英文模型转中文视频产出空文本触发)。 + if transcript is None or not (transcript.segments or (transcript.full_text or "").strip()): + raise RuntimeError( + "转写结果为空:可能视频没有人声,或当前转写引擎/模型与视频语言不匹配" + "(例如用英文模型转中文视频)。请到「设置 → 音频转写配置」检查后重试。" + ) + payload = json.dumps(asdict(transcript), ensure_ascii=False, indent=2) + transcript_cache_file.write_text(payload, encoding="utf-8") + self._write_transcript_cache(transcript, video_cache_file, "视频级转写缓存") + logger.info(f"转写并缓存成功 ({transcript_cache_file})") + return transcript + except Exception as exc: + logger.error(f"音频转写失败:{exc}") + self._handle_exception(task_id, exc) + raise + + def repolish( + self, + task_id: str, + style: Optional[str], + extras: Optional[str], + provider_id: str, + model_name: str, + ) -> str: + """ + 基于已生成笔记 + 缓存的 transcript 重新润色一版 markdown: + - 跳过下载、转写、截图替换等重型环节 + - 只调 LLM 拿一段新 markdown 文本返回(由路由层负责追加到版本列表) + """ + note_path = NOTE_OUTPUT_DIR / f"{task_id}.json" + if not note_path.exists(): + raise NoteError( + code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code, + message=f"笔记不存在:{task_id}", + ) + data = json.loads(note_path.read_text(encoding="utf-8")) + + audio_meta_d = data.get("audio_meta") or {} + transcript_d = data.get("transcript") or {} + segments_raw = transcript_d.get("segments") or [] + if not segments_raw: + raise NoteError( + code=NoteErrorEnum.PLATFORM_NOT_SUPPORTED.code, + message="缺少转写文本,无法重新润色", + ) + + segments = [] + for s in segments_raw: + try: + segments.append(TranscriptSegment(**s)) + except TypeError: + # 兼容只有 start/end/text 字段的旧记录 + segments.append(TranscriptSegment( + start=float(s.get("start", 0)), + end=float(s.get("end", 0)), + text=s.get("text", ""), + )) + + title = audio_meta_d.get("title") or "笔记" + raw_info = audio_meta_d.get("raw_info") or {} + tags = raw_info.get("tags") if isinstance(raw_info.get("tags"), list) else [] + + gpt = self._get_gpt(model_name, provider_id) + source = GPTSource( + title=title, + segment=segments, + tags=tags, + screenshot=False, + video_img_urls=[], + link=False, + _format=[], + style=style, + extras=extras, + # 用独立 checkpoint_key 避免和原始生成共用 prompt 缓存 + checkpoint_key=f"{task_id}_repolish_{int(time.time())}", + ) + markdown = gpt.summarize(source) + logger.info(f"repolish 完成 task_id={task_id} style={style}") + # 润色版同样做目录区块整形 + return normalize_toc(markdown) + + def _summarize_text( + self, + audio_meta: AudioDownloadResult, + transcript: TranscriptResult, + gpt: GPT, + markdown_cache_file: Path, + link: bool, + screenshot: bool, + formats: List[str], + style: Optional[str], + extras: Optional[str], + video_img_urls: List[str], + ) -> str | None: + """ + 调用 GPT 对转写结果进行总结,生成 Markdown 文本并缓存。 + + :param audio_meta: AudioDownloadResult 元信息 + :param transcript: TranscriptResult 转写结果 + :param gpt: GPT 实例 + :param markdown_cache_file: Markdown 缓存路径 + :param link: 是否在笔记中插入链接 + :param screenshot: 是否在笔记中生成截图占位 + :param formats: 包含 'link' 或 'screenshot' 的列表 + :param style: GPT 输出风格 + :param extras: GPT 额外参数 + :return: 生成的 Markdown 字符串 + """ + # markdown_cache_file 名为 "{task_id}_markdown.md",stem 会带上 "_markdown" 后缀, + # 直接用它会把状态写到 "{task_id}_markdown.status.json"(错误文件),导致前端 + # 轮询的 "{task_id}.status.json" 永远看不到 SUMMARIZING,进度条从转写直接跳到完成。 + task_id = markdown_cache_file.stem + if task_id.endswith("_markdown"): + task_id = task_id[: -len("_markdown")] + self._update_status(task_id, TaskStatus.SUMMARIZING) + + source = GPTSource( + title=audio_meta.title, + segment=transcript.segments, + tags=audio_meta.raw_info.get("tags", []), + screenshot=screenshot, + video_img_urls=video_img_urls, + link=link, + _format=formats, + style=style, + extras=extras, + checkpoint_key=task_id, + ) + + try: + markdown = gpt.summarize(source) + markdown_cache_file.write_text(markdown, encoding="utf-8") + logger.info(f"GPT 总结并缓存成功 ({markdown_cache_file})") + return markdown + except Exception as exc: + logger.error(f"GPT 总结失败:{exc}") + self._handle_exception(task_id, exc) + raise + + def _post_process_markdown( + self, + markdown: str, + video_path: Optional[Path], + formats: List[str], + audio_meta: AudioDownloadResult, + platform: str, + ) -> str: + """ + 对生成的 Markdown 做后期处理:插入截图和/或插入链接。 + + :param markdown: 原始 Markdown 字符串 + :param video_path: 本地视频路径(可为 None) + :param formats: 包含 'link' 或 'screenshot' 的列表 + :param audio_meta: AudioDownloadResult 元信息,用于链接替换 + :param platform: 平台标识,用于链接替换 + :return: 处理后的 Markdown 字符串 + """ + if "screenshot" in formats and video_path: + try: + markdown = ensure_screenshot_markers(markdown, audio_meta.duration) + markdown = self._insert_screenshots(markdown, video_path) + except Exception as exc: + logger.warning("截图插入失败,跳过该步骤") + + if "link" in formats: + try: + markdown = replace_content_markers(markdown, video_id=audio_meta.video_id, platform=platform) + except Exception as e: + logger.warning(f"链接插入失败,跳过该步骤:{e}") + + return markdown + + def _insert_screenshots(self, markdown: str, video_path: Path) -> str: + """ + 扫描 Markdown 文本中所有 Screenshot 标记,并替换为实际生成的截图链接。 + + :param markdown: 含有 *Screenshot-mm:ss 或 Screenshot-[mm:ss] 标记的 Markdown 文本 + :param video_path: 本地视频文件路径 + :return: 替换后的 Markdown 字符串 + """ + matches: List[Tuple[str, int]] = extract_screenshot_timestamps(markdown) + if not matches: + content_times = extract_content_timestamps(markdown) + matches = [(f"*Content-[{ts // 60:02d}:{ts % 60:02d}]", ts) for ts in content_times] + for idx, (marker, ts) in enumerate(matches): + try: + img_path = generate_screenshot(str(video_path), str(IMAGE_OUTPUT_DIR), ts, idx) + filename = Path(img_path).name + # 构建前端可访问的 URL,例如 /static/screenshots/{filename} + img_url = f"{IMAGE_BASE_URL.rstrip('/')}/{filename}" + # 把时间戳写进 alt(「原片 @ mm:ss」),前端据此在截图下方生成 + # 「跳转原片对应时间点」的链接。alt 不显示在页面上,不影响观感。 + alt = f"原片 @ {ts // 60:02d}:{ts % 60:02d}" + if marker.startswith("*Content-"): + markdown = markdown.replace(marker, f"{marker}\n\n![{alt}]({img_url})", 1) + else: + markdown = markdown.replace(marker, f"![{alt}]({img_url})", 1) + except Exception as exc: + logger.error(f"生成截图失败 (timestamp={ts}):{exc}") + # 不再直接返回 None,而是继续处理其他截图,保留原文本 + continue + return markdown + + @staticmethod + def _extract_screenshot_timestamps(markdown: str) -> List[Tuple[str, int]]: + """ + 从 Markdown 文本中提取所有 '*Screenshot-mm:ss' 或 'Screenshot-[mm:ss]' 标记, + 返回 [(原始标记文本, 时间戳秒数), ...] 列表。 + + :param markdown: 原始 Markdown 文本 + :return: 标记与对应时间戳秒数的列表 + """ + return extract_screenshot_timestamps(markdown) + + def _save_metadata(self, video_id: str, platform: str, task_id: str) -> None: + """ + 将生成的笔记任务记录插入数据库 + + :param video_id: 视频 ID + :param platform: 平台标识 + :param task_id: 任务 ID + """ + try: + insert_video_task(video_id=video_id, platform=platform, task_id=task_id) + logger.info(f"已保存任务记录到数据库 (video_id={video_id}, platform={platform}, task_id={task_id})") + except Exception as e: + logger.error(f"保存任务记录失败:{e}") diff --git a/backend/app/services/notification.py b/backend/app/services/notification.py new file mode 100644 index 0000000000000000000000000000000000000000..d8812234fdeec4c7b877d1c89fab69c1775f035c --- /dev/null +++ b/backend/app/services/notification.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import json +import logging +import smtplib +from email.mime.text import MIMEText +from typing import Any + +import requests + +from app.db.trend_subscription_dao import get_channel + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 10 # seconds + + +class NotificationService: + """Dispatches push notifications through configured channels.""" + + def send( + self, + channel_id: int, + title: str, + body: str, + url: str = "", + ) -> dict: + """Send a notification through a specific channel. Returns result dict.""" + channel = get_channel(channel_id) + if channel is None: + return {"success": False, "error": f"Channel {channel_id} not found"} + if not channel.enabled: + return {"success": False, "error": "Channel is disabled"} + + config = json.loads(channel.config or "{}") + + try: + if channel.type == "webhook": + return self._send_webhook(config, title, body, url) + elif channel.type == "bark": + return self._send_bark(config, title, body, url) + elif channel.type == "email": + return self._send_email(config, title, body) + else: + return {"success": False, "error": f"Unknown channel type: {channel.type}"} + except Exception as exc: + logger.exception(f"Notification failed for channel {channel_id}") + return {"success": False, "error": str(exc)} + + def send_batch( + self, + channel_ids: list[int], + title: str, + body: str, + url: str = "", + ) -> list[dict]: + """Send to multiple channels. Returns list of per-channel results.""" + results: list[dict] = [] + for cid in channel_ids: + results.append(self.send(cid, title, body, url)) + return results + + def send_test(self, channel_id: int) -> dict: + """Send a test notification to verify channel config.""" + return self.send( + channel_id=channel_id, + title="🎯 VideoMemo 测试通知", + body="如果你收到这条消息,说明通知通道配置成功!\n\nIf you see this, the notification channel is working!", + url="", + ) + + # ─── Channel implementations ───────────────────────────────────────────────── + + def _send_webhook(self, config: dict, title: str, body: str, url: str) -> dict: + webhook_url = str(config.get("url") or "").strip() + if not webhook_url: + return {"success": False, "error": "Webhook URL is empty"} + + payload: dict[str, Any] = { + "title": title, + "body": body, + } + if url: + payload["url"] = url + + # Support custom payload template + template = config.get("template", "") + if template: + try: + payload = json.loads( + template.replace("{{title}}", json.dumps(title)) + .replace("{{body}}", json.dumps(body)) + .replace("{{url}}", json.dumps(url)) + ) + except json.JSONDecodeError: + pass + + resp = requests.post( + webhook_url, + json=payload, + timeout=DEFAULT_TIMEOUT, + headers={"Content-Type": "application/json"}, + ) + resp.raise_for_status() + return {"success": True, "status_code": resp.status_code} + + def _send_bark(self, config: dict, title: str, body: str, url: str) -> dict: + bark_url = str(config.get("url") or "https://api.day.app/push").strip() + device_key = str(config.get("device_key") or "").strip() + if not device_key: + return {"success": False, "error": "Bark device key is empty"} + + full_url = f"{bark_url.rstrip('/')}/{device_key}" + params: dict[str, str] = { + "title": title, + "body": body, + } + if url: + params["url"] = url + if config.get("sound"): + params["sound"] = config["sound"] + if config.get("group"): + params["group"] = config["group"] + + resp = requests.post(full_url, json=params, timeout=DEFAULT_TIMEOUT) + resp.raise_for_status() + return {"success": True, "status_code": resp.status_code} + + def _send_email(self, config: dict, title: str, body: str) -> dict: + smtp_host = str(config.get("smtp_host") or "").strip() + smtp_port = int(config.get("smtp_port") or 587) + smtp_user = str(config.get("smtp_user") or "").strip() + smtp_password = str(config.get("smtp_password") or "").strip() + to_addr = str(config.get("to") or "").strip() + + if not all([smtp_host, smtp_user, smtp_password, to_addr]): + return {"success": False, "error": "Email config incomplete"} + + msg = MIMEText(body, "plain", "utf-8") + msg["Subject"] = title + msg["From"] = smtp_user + msg["To"] = to_addr + + with smtplib.SMTP(smtp_host, smtp_port, timeout=DEFAULT_TIMEOUT) as server: + server.starttls() + server.login(smtp_user, smtp_password) + server.send_message(msg) + + return {"success": True} diff --git a/backend/app/services/provider.py b/backend/app/services/provider.py new file mode 100644 index 0000000000000000000000000000000000000000..52b670d88238ef105d9858db96b9cc7c1ed6c991 --- /dev/null +++ b/backend/app/services/provider.py @@ -0,0 +1,151 @@ +from fastapi.encoders import jsonable_encoder +from kombu import uuid + +from app.db.models.providers import Provider +from app.db.provider_dao import ( + insert_provider, + get_all_providers, + get_provider_by_name, + get_provider_by_id, + update_provider, + delete_provider, get_enabled_providers, +) +from app.gpt.gpt_factory import GPTFactory +from app.models.model_config import ModelConfig + + +class ProviderService: + + @staticmethod + def serialize_provider(row: Provider) -> dict: + if not row: + return None + row = ProviderService.provider_to_dict(row) + return { + "id": row.get("id"), + "name": row.get("name"), + "logo": row.get("logo"), + "type":row.get("type"), + "enabled": row.get("enabled"), + "base_url": row.get("base_url"), + "api_key": row.get("api_key"), + "created_at": jsonable_encoder(row.get("created_at")), + # "name": row[1], + # "logo": row[2], + # "type": row[3], + # "api_key": row[4], + # "base_url": row[5], + # "enabled": row[6], + # "created_at": row[7], + } + @staticmethod + def serialize_provider_safe(row: Provider) -> dict: + if not row: + return None + row = ProviderService.provider_to_dict(row) + + return { + "id": row.get("id"), + "name": row.get("name"), + "logo": row.get("logo"), + "type":row.get("type"), + "enabled": row.get("enabled"), + "base_url": row.get("base_url"), + "api_key": ProviderService.mask_key(row.get("api_key")), + "created_at": jsonable_encoder(row.get("created_at")), + + # "id": row[0], + # "name": row[1], + # "logo": row[2], + # "type": row[3], + # "api_key": ProviderService.mask_key(row[4]), + # "base_url": row[5], + # "enabled": row[6], + # "created_at": row[7], + } + @staticmethod + def mask_key(key: str) -> str: + if not key or len(key) < 8: + return '*' * len(key) + return key[:4] + '*' * (len(key) - 8) + key[-4:] + @staticmethod + def add_provider( name: str, api_key: str, base_url: str, logo: str, type_: str, enabled: int = 1): + try: + # 内置供应商(type='built-in')只能由 seed 流程写入;API 创建一律落到 'custom', + # 否则历史上出现过批量伪内置脏数据 + if type_ != 'custom': + type_ = 'custom' + existing = get_provider_by_name(name) + if existing is not None: + raise ValueError(f'供应商名称已存在: {name}') + id = uuid().lower() + logo = 'custom' + return insert_provider(id, name, api_key, base_url, logo, type_, enabled) + except Exception as e: + print('创建模式失败',e) + raise + @staticmethod + def provider_to_dict(p: Provider): + return { + "id": p.id, + "name": p.name, + "logo": p.logo, + "type": p.type, + "api_key": p.api_key, + "base_url": p.base_url, + "enabled": p.enabled, + "created_at": p.created_at, + } + @staticmethod + def get_all_providers(): + rows = get_all_providers() + if rows is None: + return [] + + return [ProviderService.serialize_provider(row) for row in rows] if rows else [] + @staticmethod + def get_all_providers_safe(): + rows = get_all_providers() + + return [ProviderService.serialize_provider(row) for row in rows] if (rows) else [] + @staticmethod + def get_provider_by_name(name: str): + row = get_provider_by_name(name) + return ProviderService.serialize_provider(row) + + @staticmethod + def get_provider_by_id(id: str): # 已改为 str 类型 + row = get_provider_by_id(id) + return ProviderService.serialize_provider(row) + + @staticmethod + def get_provider_by_id_safe(id: str): # 已改为 str 类型 + row = get_provider_by_id(id) + return ProviderService.serialize_provider_safe(row) + # all_models.extend(provider['models']) + + @staticmethod + def update_provider(id: str, data: dict)->str | None: + try: + # 过滤掉空值 + filtered_data = {k: v for k, v in data.items() if v is not None and k != 'id'} + # 防御掩码污染:前端展示时 api_key 被 mask_key() 处理过(如 a92f****...2d3a), + # 如果用户未重新输入直接保存,带星号的值不应覆盖原 key。 + if 'api_key' in filtered_data and '*' in str(filtered_data.get('api_key', '')): + filtered_data.pop('api_key') + print('更新模型供应商',filtered_data) + update_provider(id, **filtered_data) + # 获取更新后的供应商信息 + updated_provider = get_provider_by_id(id) + return { + 'id': id, + 'enabled': updated_provider.enabled, + } + + except Exception as e: + print('更新模型供应商失败:',e) + return None + + @staticmethod + def delete_provider(id: str): + return delete_provider(id) diff --git a/backend/app/services/proxy_config_manager.py b/backend/app/services/proxy_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..abbc636b8305ae379aa7d899c93d9977b1297599 --- /dev/null +++ b/backend/app/services/proxy_config_manager.py @@ -0,0 +1,60 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ProxyConfigManager: + """全局代理配置,存 JSON 文件,支持前端动态修改。 + + 作用范围:LLM API + 转写 API(Groq 等)+ yt-dlp 视频下载。 + 优先级:配置文件里 enabled=true 的 url > 环境变量 HTTP_PROXY/HTTPS_PROXY/ALL_PROXY。 + 这样桌面端/web 用户在设置页填,docker/服务器部署用环境变量兜底。 + """ + + def __init__(self, filepath: str = "config/proxy.json"): + self.path = Path(filepath) + self.path.parent.mkdir(parents=True, exist_ok=True) + + def _read(self) -> Dict[str, Any]: + if not self.path.exists(): + return {} + try: + with self.path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + def _write(self, data: Dict[str, Any]): + with self.path.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def get_config(self) -> Dict[str, Any]: + data = self._read() + return { + "enabled": bool(data.get("enabled", False)), + "url": data.get("url", "") or "", + } + + def update_config(self, enabled: bool, url: Optional[str] = None) -> Dict[str, Any]: + data = self._read() + data["enabled"] = bool(enabled) + if url is not None: + data["url"] = url.strip() + self._write(data) + return self.get_config() + + def get_proxy_url(self) -> Optional[str]: + """返回当前生效的代理 URL;没有则 None。 + + - 配置文件 enabled=true 且 url 非空 → 用配置的 url + - 否则回退到环境变量(标准的 HTTP_PROXY / HTTPS_PROXY / ALL_PROXY,大小写都认) + """ + cfg = self.get_config() + if cfg["enabled"] and cfg["url"]: + return cfg["url"] + for key in ("HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy", "ALL_PROXY", "all_proxy"): + val = os.environ.get(key) + if val: + return val + return None diff --git a/backend/app/services/scheduler.py b/backend/app/services/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff6061a8273a0709404290947d57af49229c385 --- /dev/null +++ b/backend/app/services/scheduler.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import logging +import os +import threading +from typing import Any + +logger = logging.getLogger(__name__) + + +class TrendScheduler: + """Simple background scheduler for periodic trend matching and notification. + + Uses threading.Timer for simplicity. Configured via environment variables: + - TREND_CHECK_INTERVAL_MINUTES: how often to run (default 30) + - TREND_SCHEDULER_ENABLED: set to "false" to disable (default true) + """ + + def __init__(self): + self._timer: threading.Timer | None = None + self._running = False + self._interval_minutes = int(os.getenv("TREND_CHECK_INTERVAL_MINUTES", "30")) + self._enabled = os.getenv("TREND_SCHEDULER_ENABLED", "true").lower() != "false" + + @property + def interval_seconds(self) -> int: + return max(60, self._interval_minutes * 60) # minimum 1 minute + + def start(self) -> None: + if not self._enabled: + logger.info("TrendScheduler 已禁用 (TREND_SCHEDULER_ENABLED=false)") + return + if self._running: + return + self._running = True + logger.info(f"TrendScheduler 已启动,间隔 {self._interval_minutes} 分钟") + self._schedule_next() + + def stop(self) -> None: + self._running = False + if self._timer is not None: + self._timer.cancel() + self._timer = None + logger.info("TrendScheduler 已停止") + + def run_now(self) -> dict[str, Any]: + """Manually trigger a full matching cycle. Returns summary.""" + logger.info("TrendScheduler 手动触发匹配…") + try: + from app.services.trend_subscription import TrendSubscriptionService + from app.services.notification import NotificationService + + svc = TrendSubscriptionService() + summary = svc.match_all_subscriptions() + + # Send notifications for subscriptions with new matches + if summary["total_new_matches"] > 0: + notifier = NotificationService() + for sub_result in summary["by_subscription"]: + sub = svc.get_subscription(sub_result["subscription_id"]) + if not sub or not sub.get("push_enabled"): + continue + channel_ids = sub.get("push_channel_ids", []) + if not channel_ids: + continue + + match_titles = [m["title"] for m in sub_result["matches"]] + title = f"🔥 VideoMemo: {sub['name']} — {len(match_titles)} 条新热点" + body = "\n\n".join(f"• {t}" for t in match_titles[:10]) + if len(match_titles) > 10: + body += f"\n\n…共 {len(match_titles)} 条" + + notifier.send_batch(channel_ids, title, body) + + logger.info( + f"TrendScheduler 匹配完成: " + f"{summary['total_subscriptions']} 订阅, " + f"{summary['total_new_matches']} 新匹配" + ) + return summary + except Exception: + logger.exception("TrendScheduler 匹配出错") + return {"error": "匹配过程出错,详见日志"} + + def _schedule_next(self) -> None: + if not self._running: + return + + def _tick(): + self.run_now() + self._schedule_next() + + self._timer = threading.Timer(self.interval_seconds, _tick) + self._timer.daemon = True + self._timer.start() + + +# Module-level singleton +_scheduler: TrendScheduler | None = None + + +def get_scheduler() -> TrendScheduler: + global _scheduler + if _scheduler is None: + _scheduler = TrendScheduler() + return _scheduler diff --git a/backend/app/services/task_control.py b/backend/app/services/task_control.py new file mode 100644 index 0000000000000000000000000000000000000000..95c5de0ed1fc777f911d4301fb9cd84975da6cf6 --- /dev/null +++ b/backend/app/services/task_control.py @@ -0,0 +1,32 @@ +"""任务暂停控制:进程内共享的暂停标记。 + +NoteGenerator 在步骤之间检查这里的标记;前端通过 /task_control 接口设置。 +仅进程内有效(与现有 ThreadPoolExecutor 执行模型一致)。 +""" +import threading + +_lock = threading.Lock() +_paused: set[str] = set() + + +def pause(task_id: str) -> None: + if not task_id: + return + with _lock: + _paused.add(task_id) + + +def resume(task_id: str) -> None: + with _lock: + _paused.discard(task_id) + + +def is_paused(task_id: str) -> bool: + with _lock: + return task_id in _paused + + +def clear(task_id: str) -> None: + """任务结束(成功/失败)时清理标记。""" + with _lock: + _paused.discard(task_id) diff --git a/backend/app/services/task_serial_executor.py b/backend/app/services/task_serial_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..f4017f9281eb04869ad064a6034efdc54a2a2d16 --- /dev/null +++ b/backend/app/services/task_serial_executor.py @@ -0,0 +1,23 @@ +import os +from concurrent.futures import ThreadPoolExecutor, Future +from typing import Any, Callable + + +class ConcurrentTaskExecutor: + """使用线程池并发执行任务,替代原来的串行锁。""" + + def __init__(self, max_workers: int | None = None): + self._max_workers = max_workers or int(os.getenv("TASK_MAX_WORKERS", "3")) + self._pool = ThreadPoolExecutor(max_workers=self._max_workers) + + def run(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + future: Future = self._pool.submit(fn, *args, **kwargs) + return future.result() + + def shutdown(self, wait: bool = True): + self._pool.shutdown(wait=wait) + + +# 保持向后兼容的导出名 +SerialTaskExecutor = ConcurrentTaskExecutor +task_serial_executor = ConcurrentTaskExecutor() diff --git a/backend/app/services/transcriber_config_manager.py b/backend/app/services/transcriber_config_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ead81b87b6c2cdb57758f683c35a96c29277febf --- /dev/null +++ b/backend/app/services/transcriber_config_manager.py @@ -0,0 +1,174 @@ +import json +import os +from pathlib import Path +from typing import Optional, Dict, Any + + +class TranscriberConfigManager: + """管理转写器配置,存储在 JSON 文件中,支持前端动态修改。""" + + def __init__(self, filepath: str = "config/transcriber.json"): + self.path = Path(filepath) + self.path.parent.mkdir(parents=True, exist_ok=True) + + def _read(self) -> Dict[str, Any]: + if not self.path.exists(): + return {} + try: + with self.path.open("r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return {} + + def _write(self, data: Dict[str, Any]): + with self.path.open("w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=2) + + def get_config(self) -> Dict[str, Any]: + """获取当前转写器配置,fallback 到环境变量默认值。 + + whisper 默认 size 从 'medium' (~1.5GB) 改为 'tiny' (~75MB): + 新装用户没主动设置时不应该被首次下载卡住。想要更高精度可在「音频转写配置」 + 页主动切换。 + """ + data = self._read() + ttype = data.get( + "transcriber_type", + os.getenv("TRANSCRIBER_TYPE", "fast-whisper"), + ) + size = data.get( + "whisper_model_size", + os.getenv("WHISPER_MODEL_SIZE", "tiny"), + ) + # 防御:存储/环境变量里的值不在可选列表时回退到第一个, + # 避免前端下拉框初始化为空或指向不存在的引擎/模型 + if ttype not in ("fast-whisper", "bcut", "kuaishou", "groq", "mlx-whisper", "funasr"): + ttype = "fast-whisper" + # "custom" 表示用户自定义本地/HF whisper 模型(路径见 whisper_custom_model) + if size not in ("tiny", "base", "small", "medium", "large-v3", "large-v3-turbo", "custom"): + size = "tiny" + return { + "transcriber_type": ttype, + "whisper_model_size": size, + # 自定义 whisper 模型:本地 CTranslate2 目录 或 HF 仓库 id + "whisper_custom_model": (data.get("whisper_custom_model") or "").strip(), + # FunASR 模型名/路径(modelscope id 或本地目录),默认中文 paraformer-zh + "funasr_model": (data.get("funasr_model") or "paraformer-zh").strip(), + } + + def update_config( + self, + transcriber_type: str, + whisper_model_size: Optional[str] = None, + whisper_custom_model: Optional[str] = None, + funasr_model: Optional[str] = None, + ) -> Dict[str, Any]: + """更新转写器配置并持久化。""" + data = self._read() + data["transcriber_type"] = transcriber_type + if whisper_model_size is not None: + data["whisper_model_size"] = whisper_model_size + if whisper_custom_model is not None: + data["whisper_custom_model"] = whisper_custom_model.strip() + if funasr_model is not None: + data["funasr_model"] = funasr_model.strip() + self._write(data) + return self.get_config() + + def get_transcriber_type(self) -> str: + return self.get_config()["transcriber_type"] + + def get_whisper_model_size(self) -> str: + return self.get_config()["whisper_model_size"] + + def is_model_ready(self) -> Dict[str, Any]: + """当前转写器是否就绪可用。 + + 返回 {ready, transcriber_type, model_size, downloading, reason}: + - 在线引擎 (groq/bcut/kuaishou):永远 ready(不需要本地模型) + - fast-whisper:检查 whisper-{size}/model.bin 落盘 + - mlx-whisper:检查 {repo_id}/config.json 落盘 + 给 /generate_note 入口做「开始视频前先确认模型下载好」的门禁用。 + """ + cfg = self.get_config() + ttype = cfg["transcriber_type"] + size = cfg["whisper_model_size"] + result = { + "ready": True, + "transcriber_type": ttype, + "model_size": size, + "downloading": False, + "reason": "", + } + # FunASR:可选引擎,需安装 funasr+torch。模型经 modelscope 首跑自动下载, + # 不做预下载门禁,只确认引擎可用,否则给安装指引。 + if ttype == "funasr": + try: + from app.transcriber.transcriber_provider import FUNASR_AVAILABLE + except Exception: + FUNASR_AVAILABLE = True # 检查不了就放行,交给后续流程报错 + if not FUNASR_AVAILABLE: + result["ready"] = False + result["reason"] = ( + "FunASR 引擎当前不可用(未安装)。请安装依赖:" + "pip install funasr torch torchaudio,安装后重启后端;或切换到其他转写引擎。" + ) + return result + + if ttype not in ("fast-whisper", "mlx-whisper"): + return result # 在线引擎无需本地模型 + + # fast-whisper 自定义模型:路径/仓库 id 由用户自负,本地目录存在即就绪; + # 仓库 id 也放行(首跑联网下载),不进预设档位的下载门禁。 + if ttype == "fast-whisper" and size == "custom": + custom = cfg.get("whisper_custom_model") or "" + if not custom: + result["ready"] = False + result["reason"] = "已选「自定义」Whisper 模型,但未填写模型路径或仓库 id。" + return result + + # mlx-whisper 还要求引擎本身可用(包已安装且原生库能加载)。 + # 配置可能是在引擎可用时保存的,之后换了环境/重装应用就失效了—— + # 在这里拦下并给出可行动的指引,而不是让 NoteGenerator 初始化时 500。 + if ttype == "mlx-whisper": + try: + from app.transcriber.transcriber_provider import MLX_WHISPER_AVAILABLE + except Exception: + MLX_WHISPER_AVAILABLE = True # 检查不了就放行,交给后续流程报错 + if not MLX_WHISPER_AVAILABLE: + result["ready"] = False + result["reason"] = ( + "MLX Whisper 引擎当前不可用(未安装或本机不支持)。" + "请到「设置 → 音频转写配置」按页面提示安装 mlx_whisper 后重启应用," + "或切换到其他转写引擎。" + ) + return result + + # 延迟 import 避免与 routers.config 的循环依赖;只取纯函数,不触发路由副作用 + try: + from app.routers.config import ( + _check_whisper_model_exists, + _check_mlx_whisper_model_exists, + _downloading, + ) + except Exception as e: + # 拿不到检查函数时保守放行,不要把用户卡死 + result["reason"] = f"无法检查模型状态: {e}" + return result + + if ttype == "fast-whisper": + downloaded = _check_whisper_model_exists(size, "whisper") + downloading = _downloading.get(size) == "downloading" + else: # mlx-whisper + downloaded = _check_mlx_whisper_model_exists(size) + downloading = _downloading.get(f"mlx-{size}") == "downloading" + + result["downloading"] = downloading + if downloaded: + return result + result["ready"] = False + result["reason"] = ( + f"转写模型 {ttype} / {size} 尚未下载就绪" + + (",正在下载中,请稍候" if downloading else ",请先在「设置 → 音频转写配置」页下载") + ) + return result diff --git a/backend/app/services/trend_subscription.py b/backend/app/services/trend_subscription.py new file mode 100644 index 0000000000000000000000000000000000000000..d9a009a15f7a93017bd98702390e8eebc0db5123 --- /dev/null +++ b/backend/app/services/trend_subscription.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +import json +import re +from typing import Any + +from app.db.trend_subscription_dao import ( + count_unread_matches, + create_match, + create_subscription, + delete_subscription, + get_subscription, + list_matches, + list_subscriptions, + mark_matches_read, + update_subscription, + update_subscription_refresh, +) +from app.services.hot_videos import HotVideoItem, fetch_hot_videos + + +def _parse_keywords(raw: Any) -> list[str]: + """Parse keywords from stored JSON or list.""" + if isinstance(raw, list): + return [str(k).strip() for k in raw if str(k).strip()] + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return _parse_keywords(parsed) + except (json.JSONDecodeError, TypeError): + return [k.strip() for k in raw.split(",") if k.strip()] + return [] + + +def _parse_platforms(raw: Any) -> list[str]: + """Parse platform list from stored JSON or string.""" + if isinstance(raw, list): + return [str(p).strip() for p in raw if str(p).strip()] + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return _parse_platforms(parsed) + except (json.JSONDecodeError, TypeError): + return [p.strip() for p in raw.split(",") if p.strip()] + return ["all"] + + +def _match_keywords(title: str, keywords: list[str], mode: str = "any") -> tuple[bool, list[str]]: + """Match title against keywords. Returns (matched, [matched_keyword_strings]). + + Keyword syntax (inspired by TrendRadar): + - Plain keyword: case-insensitive substring match + - +keyword: must-have (required) + - -keyword: exclude (if matched, the item is rejected entirely) + - /pattern/: regex match + """ + if not keywords: + # Empty keywords = match all items (subscribe to entire platform) + return True, ["*"] + + title_lower = title.lower() + must_haves: list[str] = [] + excludes: list[str] = [] + normal: list[str] = [] + + for kw in keywords: + kw = kw.strip() + if not kw: + continue + if kw.startswith("+"): + must_haves.append(kw[1:]) + elif kw.startswith("-"): + excludes.append(kw[1:]) + else: + normal.append(kw) + + # Check excludes first — if any exclude matches, reject immediately + for ex in excludes: + ex_lower = ex.lower() + if ex_lower in title_lower: + return False, [] + + all_keywords = must_haves + normal + if not all_keywords: + return False, [] + + matched: list[str] = [] + for kw in all_keywords: + kw_lower = kw.lower() + # Try regex if wrapped in /slashes/ + if kw.startswith("/") and kw.endswith("/") and len(kw) > 2: + try: + if re.search(kw[1:-1], title, re.IGNORECASE): + matched.append(kw) + except re.error: + pass + elif kw_lower in title_lower: + matched.append(kw) + + if mode == "all": + return len(matched) == len(all_keywords), matched + # mode == "any" + return len(matched) > 0, matched + + +class TrendSubscriptionService: + """Service for managing trend keyword subscriptions and matching hot items.""" + + # ─── Subscription CRUD ─────────────────────────────────────────────────────── + + def list_subscriptions(self) -> list[dict]: + subs = list_subscriptions() + result: list[dict] = [] + for sub in subs: + d = self._sub_to_dict(sub) + d["unread_count"] = count_unread_matches(sub.id) + result.append(d) + return result + + def get_subscription(self, subscription_id: int) -> dict | None: + sub = get_subscription(subscription_id) + if sub is None: + return None + d = self._sub_to_dict(sub) + d["unread_count"] = count_unread_matches(sub.id) + return d + + def create_subscription( + self, + name: str, + keywords: list[str], + platforms: list[str] | None = None, + match_mode: str = "any", + push_enabled: bool = False, + push_channel_ids: list[int] | None = None, + ) -> dict: + sub = create_subscription( + name=name, + keywords=keywords, + platforms=platforms, + match_mode=match_mode, + push_enabled=push_enabled, + push_channel_ids=push_channel_ids, + ) + return self._sub_to_dict(sub) + + def update_subscription( + self, + subscription_id: int, + name: str | None = None, + keywords: list[str] | None = None, + platforms: list[str] | None = None, + match_mode: str | None = None, + enabled: bool | None = None, + push_enabled: bool | None = None, + push_channel_ids: list[int] | None = None, + ) -> dict | None: + sub = update_subscription( + subscription_id=subscription_id, + name=name, + keywords=keywords, + platforms=platforms, + match_mode=match_mode, + enabled=enabled, + push_enabled=push_enabled, + push_channel_ids=push_channel_ids, + ) + return self._sub_to_dict(sub) if sub else None + + def delete_subscription(self, subscription_id: int) -> bool: + return delete_subscription(subscription_id) + + # ─── Matching ──────────────────────────────────────────────────────────────── + + def match_subscription(self, subscription_id: int) -> dict: + """Fetch hot items and match against this subscription. Returns summary.""" + sub = get_subscription(subscription_id) + if sub is None: + raise ValueError(f"Subscription {subscription_id} not found") + + keywords = _parse_keywords(sub.keywords) + platforms = _parse_platforms(sub.platforms) + match_mode = sub.match_mode or "any" + + new_matches: list[dict] = [] + # Fetch from each platform + for platform in platforms: + try: + results = fetch_hot_videos(platform=platform, limit=20) + except Exception: + continue + + for result in results: + if result.status != "ok": + continue + for item in result.items: + matched, matched_kws = _match_keywords(item.title, keywords, match_mode) + if matched: + match = create_match( + subscription_id=subscription_id, + platform=item.platform, + item_id=item.id, + title=item.title, + url=item.url, + hot_score=item.hot_score, + matched_keywords=matched_kws, + ) + if match is not None: + new_matches.append(self._match_to_dict(match)) + + update_subscription_refresh(subscription_id) + return { + "subscription_id": subscription_id, + "new_matches": len(new_matches), + "matches": new_matches, + } + + def match_all_subscriptions(self) -> dict: + """Match all enabled subscriptions. Returns summary for notifications.""" + subs = list_subscriptions() + summary: dict[str, Any] = {"total_subscriptions": 0, "total_new_matches": 0, "by_subscription": []} + + for sub in subs: + if not sub.enabled: + continue + summary["total_subscriptions"] += 1 + try: + result = self.match_subscription(sub.id) + if result["new_matches"] > 0: + summary["by_subscription"].append(result) + summary["total_new_matches"] += result["new_matches"] + except Exception: + continue + + return summary + + # ─── Matches ───────────────────────────────────────────────────────────────── + + def list_matches( + self, + subscription_id: int | None = None, + limit: int = 100, + unread_only: bool = False, + ) -> list[dict]: + matches = list_matches(subscription_id=subscription_id, limit=limit, unread_only=unread_only) + return [self._match_to_dict(m) for m in matches] + + def mark_all_read(self, subscription_id: int) -> int: + return mark_matches_read(subscription_id) + + # ─── Helpers ───────────────────────────────────────────────────────────────── + + @staticmethod + def _sub_to_dict(sub) -> dict: + return { + "id": sub.id, + "name": sub.name, + "keywords": _parse_keywords(sub.keywords), + "platforms": _parse_platforms(sub.platforms), + "match_mode": sub.match_mode, + "enabled": sub.enabled, + "push_enabled": sub.push_enabled, + "push_channel_ids": json.loads(sub.push_channel_ids or "[]"), + "last_matched_at": sub.last_matched_at.isoformat() if sub.last_matched_at else None, + "created_at": sub.created_at.isoformat() if sub.created_at else None, + "updated_at": sub.updated_at.isoformat() if sub.updated_at else None, + } + + @staticmethod + def _match_to_dict(m) -> dict: + return { + "id": m.id, + "subscription_id": m.subscription_id, + "platform": m.platform, + "item_id": m.item_id, + "title": m.title, + "url": m.url, + "hot_score": m.hot_score, + "matched_keywords": json.loads(m.matched_keywords or "[]"), + "matched_at": m.matched_at.isoformat() if m.matched_at else None, + "is_read": m.is_read, + } diff --git a/backend/app/services/vector_store.py b/backend/app/services/vector_store.py new file mode 100644 index 0000000000000000000000000000000000000000..a4da26fe7391b6dbb053708686038b57933bbc02 --- /dev/null +++ b/backend/app/services/vector_store.py @@ -0,0 +1,267 @@ +import json +import os +import re +from typing import Optional + +import chromadb +from chromadb.config import Settings + +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +NOTE_OUTPUT_DIR = os.getenv("NOTE_OUTPUT_DIR", "note_results") +VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "vector_db") + + +def _chunk_markdown(markdown: str) -> list[dict]: + """按 H2/H3 标题拆分 markdown 为语义块。""" + sections = re.split(r'(?=^#{2,3}\s)', markdown, flags=re.MULTILINE) + chunks = [] + for section in sections: + section = section.strip() + if not section or len(section) < 30: + continue + heading_match = re.match(r'^(#{2,3})\s+(.+)', section) + title = heading_match.group(2).strip() if heading_match else "intro" + chunks.append({ + "text": section, + "metadata": {"source_type": "markdown", "section_title": title}, + }) + return chunks + + +def _chunk_transcript(segments: list[dict], window_size: int = 15, overlap: int = 3) -> list[dict]: + """将转录 segments 按滑动窗口分组。""" + if not segments: + return [] + chunks = [] + step = max(window_size - overlap, 1) + for i in range(0, len(segments), step): + window = segments[i:i + window_size] + if not window: + break + text = "\n".join( + f"[{seg.get('start', 0):.0f}s] {seg.get('text', '')}" for seg in window + ) + chunks.append({ + "text": text, + "metadata": { + "source_type": "transcript", + "start_time": window[0].get("start", 0), + "end_time": window[-1].get("end", 0), + }, + }) + return chunks + + +def _build_meta_chunk(audio_meta: dict) -> list[dict]: + """将视频元信息(标题、作者、描述、标签等)构建为可检索的 chunk。""" + if not audio_meta: + return [] + + raw = audio_meta.get("raw_info", {}) or {} + parts = [] + + title = audio_meta.get("title") or raw.get("title", "") + if title: + parts.append(f"视频标题:{title}") + + uploader = raw.get("uploader", "") + if uploader: + parts.append(f"视频作者/UP主:{uploader}") + + desc = raw.get("description", "") + if desc: + parts.append(f"视频简介:{desc[:500]}") + + tags = raw.get("tags", []) + if tags and isinstance(tags, list): + parts.append(f"标签:{', '.join(str(t) for t in tags[:20])}") + + duration = audio_meta.get("duration", 0) + if duration: + m, s = divmod(int(duration), 60) + parts.append(f"视频时长:{m}分{s}秒") + + platform = audio_meta.get("platform", "") + if platform: + parts.append(f"平台:{platform}") + + url = raw.get("webpage_url", "") + if url: + parts.append(f"链接:{url}") + + if not parts: + return [] + + return [{ + "text": "\n".join(parts), + "metadata": {"source_type": "meta"}, + }] + + +class VectorStoreManager: + """基于 ChromaDB 的笔记向量存储管理器。""" + + def __init__(self): + os.makedirs(VECTOR_DB_DIR, exist_ok=True) + self._client = chromadb.PersistentClient( + path=VECTOR_DB_DIR, + settings=Settings(anonymized_telemetry=False), + ) + + def _collection_name(self, task_id: str) -> str: + """ChromaDB collection 名称:直接使用 task_id(UUID 格式合法)。""" + return task_id + + def index_task(self, task_id: str) -> None: + """读取笔记结果并建立向量索引。""" + result_path = os.path.join(NOTE_OUTPUT_DIR, f"{task_id}.json") + if not os.path.exists(result_path): + logger.warning(f"笔记文件不存在,跳过索引: {result_path}") + return + + with open(result_path, "r", encoding="utf-8") as f: + note_data = json.load(f) + + markdown = note_data.get("markdown", "") + transcript = note_data.get("transcript", {}) + segments = transcript.get("segments", []) + + audio_meta = note_data.get("audio_meta", {}) + + meta_chunks = _build_meta_chunk(audio_meta) + md_chunks = _chunk_markdown(markdown) + tr_chunks = _chunk_transcript(segments) + all_chunks = meta_chunks + md_chunks + tr_chunks + + if not all_chunks: + logger.warning(f"笔记内容为空,跳过索引: {task_id}") + return + + col_name = self._collection_name(task_id) + + # 删除旧 collection(幂等) + try: + self._client.delete_collection(col_name) + except Exception: + pass + + collection = self._client.create_collection( + name=col_name, + metadata={"hnsw:space": "cosine"}, + ) + + documents = [c["text"] for c in all_chunks] + metadatas = [c["metadata"] for c in all_chunks] + ids = [f"{task_id}_{i}" for i in range(len(all_chunks))] + + collection.add(documents=documents, metadatas=metadatas, ids=ids) + logger.info(f"向量索引完成: task_id={task_id}, chunks={len(all_chunks)}") + + def _parse_results(self, results: dict) -> list[dict]: + """将 ChromaDB query 结果转换为 chunk 列表。""" + chunks = [] + if not results or not results.get("documents") or not results["documents"][0]: + return chunks + for i in range(len(results["documents"][0])): + chunks.append({ + "text": results["documents"][0][i], + "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, + "distance": results["distances"][0][i] if results["distances"] else None, + }) + return chunks + + def query(self, task_id: str, query_text: str, n_results: int = 6) -> list[dict]: + """ + 按固定配额从各来源检索:meta 1 条、markdown 2 条、transcript 3 条, + 确保三种来源都被召回。 + """ + col_name = self._collection_name(task_id) + try: + collection = self._client.get_collection(col_name) + except Exception: + logger.warning(f"Collection 不存在: {col_name}") + return [] + + all_chunks = [] + + # 每种来源的配额 + quotas = {"meta": 1, "markdown": 2, "transcript": 3} + + for source_type, quota in quotas.items(): + try: + results = collection.query( + query_texts=[query_text], + n_results=quota, + where={"source_type": source_type}, + ) + all_chunks.extend(self._parse_results(results)) + except Exception: + pass + + return all_chunks + + def list_indexed_task_ids(self) -> list[str]: + """返回所有已建立索引的 task_id。collection_name 与 task_id 一一对应。""" + try: + return [c.name for c in self._client.list_collections()] + except Exception as e: + logger.warning(f"列出 collection 失败: {e}") + return [] + + def query_across( + self, + query_text: str, + task_ids: Optional[list[str]] = None, + n_results_per_task: int = 3, + max_total: int = 12, + ) -> list[dict]: + """ + 跨多个笔记并行检索,按距离归并排序后截断。 + - task_ids=None: 全库(所有已索引的 task) + - 每条 chunk 额外带 task_id 字段,前端用来反查笔记 + """ + if task_ids is None: + task_ids = self.list_indexed_task_ids() + + if not task_ids: + return [] + + all_chunks: list[dict] = [] + for tid in task_ids: + try: + chunks = self.query(tid, query_text, n_results=n_results_per_task) + except Exception as e: + logger.warning(f"跨笔记检索单笔记失败 task_id={tid}: {e}") + continue + for ch in chunks: + ch["task_id"] = tid + all_chunks.extend(chunks) + + # 距离越小越相关;None 排到最后 + all_chunks.sort(key=lambda c: c.get("distance") if c.get("distance") is not None else float("inf")) + return all_chunks[:max_total] + + def delete_index(self, task_id: str) -> None: + """删除指定任务的向量索引。""" + col_name = self._collection_name(task_id) + try: + self._client.delete_collection(col_name) + logger.info(f"已删除向量索引: {task_id}") + except Exception: + pass + + def is_indexed(self, task_id: str) -> bool: + """检查指定任务是否已建立完整索引(含 meta 信息)。""" + col_name = self._collection_name(task_id) + try: + col = self._client.get_collection(col_name) + if col.count() == 0: + return False + # 检查是否包含 meta chunk,旧索引可能缺失 + meta = col.get(where={"source_type": "meta"}, limit=1) + return len(meta["ids"]) > 0 + except Exception: + return False diff --git a/backend/app/transcriber/__init__.py b/backend/app/transcriber/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/transcriber/base.py b/backend/app/transcriber/base.py new file mode 100644 index 0000000000000000000000000000000000000000..5cb27588b299225bbb8bab29631987f0da332316 --- /dev/null +++ b/backend/app/transcriber/base.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + +from app.models.transcriber_model import TranscriptResult + + +class Transcriber(ABC): + @abstractmethod + def transcript(self,file_path:str)->TranscriptResult: + ''' + + :param file_path:音频路径 + :return: 返回一个 TranscriptResult 类 + ''' + pass + + def on_finish(self,video_path:str,result: TranscriptResult)->None: + ''' + 当音频转录完成时调用 + :param video_path: 视频路径 + :param result: 识别结果 + :return: + ''' + pass \ No newline at end of file diff --git a/backend/app/transcriber/bcut.py b/backend/app/transcriber/bcut.py new file mode 100644 index 0000000000000000000000000000000000000000..215dbfb66e823bd3407768cfd4b536c76ec73510 --- /dev/null +++ b/backend/app/transcriber/bcut.py @@ -0,0 +1,282 @@ +import json +import logging +import time +from typing import Optional, List, Dict, Union + +import requests + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.logger import get_logger +from events import transcription_finished + +__version__ = "0.0.3" + +API_BASE_URL = "https://member.bilibili.com/x/bcut/rubick-interface" + +# 申请上传 +API_REQ_UPLOAD = API_BASE_URL + "/resource/create" + +# 提交上传 +API_COMMIT_UPLOAD = API_BASE_URL + "/resource/create/complete" + +# 创建任务 +API_CREATE_TASK = API_BASE_URL + "/task" + +# 查询结果 +API_QUERY_RESULT = API_BASE_URL + "/task/result" + +logger = get_logger(__name__) + + +def _bilibili_cookie() -> Optional[str]: + """读取「下载配置」里保存的 B 站 Cookie(没有则返回 None)。""" + try: + from app.services.cookie_manager import CookieConfigManager + return CookieConfigManager().get("bilibili") + except Exception: + return None + + +def _with_cookie_hint(msg: str) -> str: + """bcut 是 B 站的接口,未带 B 站 Cookie 时容易被风控拒绝(如「第三方服务异常」)。 + + 未配置 Cookie 时在报错后面追加可行动的提示;已配置则原样返回。 + """ + if "下载配置" in msg: # 已带提示,保持幂等 + return msg + if _bilibili_cookie(): + return msg + return ( + f"{msg}。bcut(必剪)转写走的是 B 站接口,未配置 B 站 Cookie 时容易被风控拒绝:" + "请在「设置 → 下载配置」中填写 B 站 Cookie 后重试," + "或在「设置 → 音频转写配置」中切换为本地转写引擎(fast-whisper / mlx-whisper)。" + ) + + +class BcutTranscriber(Transcriber): + """必剪 语音识别接口""" + headers = { + 'User-Agent': 'Bilibili/1.0.0 (https://www.bilibili.com)', + 'Content-Type': 'application/json' + } + + def __init__(self): + self.session = requests.Session() + # 带上「下载配置」里的 B 站 Cookie(如有),降低被 B 站风控拒绝的概率 + cookie = _bilibili_cookie() + if cookie: + self.headers = {**self.headers, 'Cookie': cookie} + self.task_id = None + self.__etags = [] + + self.__in_boss_key: Optional[str] = None + self.__resource_id: Optional[str] = None + self.__upload_id: Optional[str] = None + self.__upload_urls: List[str] = [] + self.__per_size: Optional[int] = None + self.__clips: Optional[int] = None + + self.__etags: List[str] = [] + self.__download_url: Optional[str] = None + self.task_id: Optional[str] = None + + def _load_file(self, file_path: str) -> bytes: + """读取文件内容""" + with open(file_path, 'rb') as f: + return f.read() + + def _upload(self, file_path: str) -> None: + """申请上传""" + file_binary = self._load_file(file_path) + if not file_binary: + raise ValueError("无法读取文件数据") + + payload = json.dumps({ + "type": 2, + "name": "audio.mp3", + "size": len(file_binary), + "ResourceFileType": "mp3", + "model_id": "8", + }) + + resp = self.session.post( + API_REQ_UPLOAD, + data=payload, + headers=self.headers + ) + resp.raise_for_status() + resp = resp.json() + resp_data = resp["data"] + + self.__in_boss_key = resp_data["in_boss_key"] + self.__resource_id = resp_data["resource_id"] + self.__upload_id = resp_data["upload_id"] + self.__upload_urls = resp_data["upload_urls"] + self.__per_size = resp_data["per_size"] + self.__clips = len(resp_data["upload_urls"]) + + logger.info( + f"申请上传成功, 总计大小{resp_data['size'] // 1024}KB, {self.__clips}分片, 分片大小{resp_data['per_size'] // 1024}KB: {self.__in_boss_key}" + ) + self.__upload_part(file_binary) + self.__commit_upload() + + def __upload_part(self, file_binary: bytes) -> None: + """上传音频数据""" + for clip in range(self.__clips): + start_range = clip * self.__per_size + end_range = min((clip + 1) * self.__per_size, len(file_binary)) + logger.info(f"开始上传分片{clip}: {start_range}-{end_range}") + resp = self.session.put( + self.__upload_urls[clip], + data=file_binary[start_range:end_range], + headers={'Content-Type': 'application/octet-stream'} + ) + resp.raise_for_status() + etag = resp.headers.get("Etag", "").strip('"') + self.__etags.append(etag) + logger.info(f"分片{clip}上传成功: {etag}") + + def __commit_upload(self) -> None: + """提交上传数据""" + data = json.dumps({ + "InBossKey": self.__in_boss_key, + "ResourceId": self.__resource_id, + "Etags": ",".join(self.__etags), + "UploadId": self.__upload_id, + "model_id": "8", + }) + resp = self.session.post( + API_COMMIT_UPLOAD, + data=data, + headers=self.headers + ) + resp.raise_for_status() + resp = resp.json() + print('Bili',resp) + if resp.get("code") != 0: + error_msg = f"上传提交失败: {resp.get('message', '未知错误')}" + logger.error(error_msg) + raise Exception(error_msg) + + self.__download_url = resp["data"]["download_url"] + logger.info(f"提交成功,下载链接: {self.__download_url}") + + def _create_task(self) -> str: + """开始创建转换任务""" + resp = self.session.post( + API_CREATE_TASK, json={"resource": self.__download_url, "model_id": "8"}, headers=self.headers + ) + resp.raise_for_status() + resp = resp.json() + if resp.get("code") != 0: + error_msg = f"创建任务失败: {resp.get('message', '未知错误')}" + logger.error(error_msg) + raise Exception(error_msg) + + self.task_id = resp["data"]["task_id"] + logger.info(f"任务已创建: {self.task_id}") + return self.task_id + + def _query_result(self) -> dict: + """查询转换结果""" + resp = self.session.get( + API_QUERY_RESULT, + params={"model_id": 7, "task_id": self.task_id}, + headers=self.headers + ) + resp.raise_for_status() + resp = resp.json() + if resp.get("code") != 0: + error_msg = f"查询结果失败: {resp.get('message', '未知错误')}" + logger.error(error_msg) + raise Exception(error_msg) + + return resp["data"] + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + """执行识别过程,符合 Transcriber 接口""" + try: + logger.info(f"开始处理文件: {file_path}") + + # 上传文件 + logger.info("正在上传文件...") + self._upload(file_path) + + # 创建任务 + logger.info("提交转录任务...") + self._create_task() + + # 轮询检查任务状态 + logger.info("等待转录结果...") + task_resp = None + max_retries = 500 + for i in range(max_retries): + task_resp = self._query_result() + + if task_resp["state"] == 4: # 完成状态 + break + elif task_resp["state"] == 3: # 失败状态 + error_msg = f"B站ASR任务失败,状态码: {task_resp['state']}" + logger.error(error_msg) + raise Exception(error_msg) + + # 每隔一段时间打印进度 + if i % 10 == 0: + logger.info(f"转录进行中... {i}/{max_retries}") + + time.sleep(1) + + if not task_resp or task_resp["state"] != 4: + error_msg = f"B站ASR任务未能完成,状态: {task_resp.get('state') if task_resp else 'Unknown'}" + logger.error(error_msg) + raise Exception(error_msg) + + # 解析结果 + logger.info("转录成功,处理结果...") + result_json = json.loads(task_resp["result"]) + + # 提取分段数据 + segments = [] + full_text = "" + + for u in result_json.get("utterances", []): + text = u.get("transcript", "").strip() + # B站ASR返回的时间戳是毫秒,需要转换为秒 + start_time = float(u.get("start_time", 0)) / 1000.0 + end_time = float(u.get("end_time", 0)) / 1000.0 + + full_text += text + " " + segments.append(TranscriptSegment( + start=start_time, + end=end_time, + text=text + )) + + # 创建结果对象 + result = TranscriptResult( + language=result_json.get("language", "zh"), + full_text=full_text.strip(), + segments=segments, + raw=result_json + ) + + # 触发完成事件 + # self.on_finish(file_path, result) + + return result + + except Exception as e: + logger.error(f"B站ASR处理失败: {str(e)}") + # 未配置 B 站 Cookie 时附加「去下载配置填 Cookie / 换本地转写引擎」的提示 + raise Exception(_with_cookie_hint(str(e))) from e + + def on_finish(self, video_path: str, result: TranscriptResult) -> None: + """转录完成的回调""" + logger.info(f"B站ASR转写完成: {video_path}") + transcription_finished.send({ + "file_path": video_path, + }) \ No newline at end of file diff --git a/backend/app/transcriber/funasr_transcriber.py b/backend/app/transcriber/funasr_transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..9c5ab9ad78330bac81f324db716791b1f5dc6f36 --- /dev/null +++ b/backend/app/transcriber/funasr_transcriber.py @@ -0,0 +1,163 @@ +import os +from typing import List + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.logger import get_logger +from events import transcription_finished + +logger = get_logger(__name__) + + +class FunASRTranscriber(Transcriber): + """FunASR(阿里达摩院)本地语音识别。 + + 中文识别效果通常优于 Whisper,自带 VAD + 标点恢复。依赖 funasr + torch(较重, + 约 2GB),属可选引擎:未安装时不可用,由 transcriber_provider 的 FUNASR_AVAILABLE + 兜底并提示安装。模型首次使用时通过 modelscope 自动下载。 + + 不同模型族初始化方式不同,按名称分支: + - paraformer 系:model + vad_model(fsmn-vad) + punc_model(ct-punc),输出 sentence_info(句级时间戳) + - SenseVoice 系:model + vad_model(不带 punc,自带标点),generate 用 language/use_itn, + 文本经 rich_transcription_postprocess 清洗;无句级时间戳,退化为整段 + """ + + def __init__( + self, + model: str = None, + device: str = None, + ): + self.model_name = (model or os.getenv("FUNASR_MODEL", "paraformer-zh")).strip() + self.device = device or os.getenv("FUNASR_DEVICE") or None + + name = self.model_name.lower() + self.is_sensevoice = "sensevoice" in name + + from funasr import AutoModel # 懒加载:import funasr 会连带加载 torch + + if self.is_sensevoice: + # SenseVoice:用全名仓库 id;只配 VAD,不配 punc(其文本自带标点/反正则) + repo = self.model_name if "/" in self.model_name else "iic/SenseVoiceSmall" + logger.info(f"初始化 FunASR(SenseVoice):model={repo}, device={self.device or 'auto'}") + kwargs = dict( + model=repo, + vad_model="fsmn-vad", + vad_kwargs={"max_single_segment_time": 30000}, + disable_update=True, + ) + else: + # paraformer 等:vad + punc,输出句级时间戳 + logger.info( + f"初始化 FunASR:model={self.model_name}, vad=fsmn-vad, punc=ct-punc, " + f"device={self.device or 'auto'}" + ) + kwargs = dict( + model=self.model_name, + vad_model="fsmn-vad", + punc_model="ct-punc", + disable_update=True, + ) + if self.device: + kwargs["device"] = self.device + + self.model = AutoModel(**kwargs) + logger.info("FunASR 模型加载完成") + + def _vocab_mismatch_hint(self, err: Exception) -> str: + return ( + f"FunASR 模型「{self.model_name}」与当前 funasr 版本不兼容" + f"(模型词表与分词器不匹配:{err})。" + "英文/多语视频建议改用 SenseVoiceSmall(设置 → 音频转写配置 → FunASR 模型)," + "或切换到 Whisper 引擎。" + ) + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + try: + logger.info(f"FunASR 开始转写:{file_path}") + segments: List[TranscriptSegment] = [] + full_text = "" + + if self.is_sensevoice: + from funasr.utils.postprocess_utils import rich_transcription_postprocess + results = self.model.generate( + input=file_path, + cache={}, + language="auto", + use_itn=True, + batch_size_s=60, + merge_vad=True, + merge_length_s=15, + ) + # SenseVoice 文本含 <|emotion|><|event|> 等标记,用官方后处理清洗 + parts = [] + for item in results or []: + raw = item.get("text", "") + parts.append(rich_transcription_postprocess(raw) if raw else "") + full_text = "".join(parts).strip() + # SenseVoice 不产句级时间戳,退化为整段 + if full_text: + segments.append(TranscriptSegment(start=0.0, end=0.0, text=full_text)) + raw_obj = results + else: + # 句级时间戳只有离线 zh 系 paraformer 支持: + # - paraformer-en:无时间戳预测器,强开会解码越界(IndexError: piece id out of range) + # - paraformer-zh-streaming:流式模型同样无时间戳,强开会 KeyError: 'timestamp' + name_l = self.model_name.lower() + want_ts = "paraformer-zh" in name_l and "streaming" not in name_l + gen_kwargs = dict(input=file_path, batch_size_s=300) + if want_ts: + gen_kwargs["sentence_timestamp"] = True + try: + results = self.model.generate(**gen_kwargs) + except (IndexError, KeyError) as e: + if want_ts: + # 保险:个别 zh 变体可能不支持句级时间戳,降级为无时间戳重试一次 + logger.warning(f"{self.model_name} 句级时间戳解码失败({e}),降级为无时间戳重试") + gen_kwargs.pop("sentence_timestamp", None) + try: + results = self.model.generate(**gen_kwargs) + except (IndexError, KeyError) as e2: + raise RuntimeError(self._vocab_mismatch_hint(e2)) from e2 + elif isinstance(e, IndexError): + # 已无时间戳仍越界:模型包词表与 funasr 解码不匹配(如 paraformer-en + # 的 bpe.model 10000 词 vs tokens.json 10020 词),属上游兼容问题 + raise RuntimeError(self._vocab_mismatch_hint(e)) from e + else: + raise + item = results[0] if isinstance(results, list) and results else (results or {}) + full_text = (item.get("text") or "").strip() + for sent in item.get("sentence_info") or []: + text = (sent.get("text") or "").strip() + if not text: + continue + # FunASR 时间戳单位毫秒 + segments.append(TranscriptSegment( + start=float(sent.get("start", 0)) / 1000.0, + end=float(sent.get("end", 0)) / 1000.0, + text=text, + )) + if not segments and full_text: + segments.append(TranscriptSegment(start=0.0, end=0.0, text=full_text)) + raw_obj = item + + if not full_text and segments: + full_text = " ".join(s.text for s in segments) + + # 语言标记按模型名推断(影响下游 prompt 等);SenseVoice 多语统一标 zh 兜底 + lang = "en" if "-en" in self.model_name.lower() else "zh" + + return TranscriptResult( + language=lang, + full_text=full_text, + segments=segments, + raw=raw_obj, + ) + except Exception as e: + logger.error(f"FunASR 转写失败:{e}") + raise + + def on_finish(self, video_path: str, result: TranscriptResult) -> None: + logger.info(f"FunASR 转写完成:{video_path}") + transcription_finished.send({"file_path": video_path}) diff --git a/backend/app/transcriber/groq.py b/backend/app/transcriber/groq.py new file mode 100644 index 0000000000000000000000000000000000000000..033ffa6ed5ebe35d56ca94b8ff8af97b04a69a2d --- /dev/null +++ b/backend/app/transcriber/groq.py @@ -0,0 +1,70 @@ +from abc import ABC +import os + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptResult, TranscriptSegment +from app.services.provider import ProviderService +from app.transcriber.base import Transcriber +from app.utils.openai_client import build_openai_client +import ffmpeg +import tempfile +from dotenv import load_dotenv +load_dotenv() +MAX_SIZE_MB = 18 +MAX_SIZE_BYTES = MAX_SIZE_MB * 1024 * 1024 +def compress_audio(input_path: str, target_bitrate='64k') -> str: + output_fd, output_path = tempfile.mkstemp(suffix=".mp3") # 临时输出文件 + os.close(output_fd) # 关闭文件描述符,ffmpeg 会用路径操作 + ffmpeg.input(input_path).output(output_path, audio_bitrate=target_bitrate).run(quiet=True, overwrite_output=True) + return output_path + +class GroqTranscriber(Transcriber, ABC): + + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + file_size = os.path.getsize(file_path) + if file_size > MAX_SIZE_BYTES: + print(f"文件超过 {MAX_SIZE_MB}MB,开始压缩(当前 {round(file_size / (1024 * 1024), 2)}MB)...") + file_path = compress_audio(file_path) + print(f"压缩完成,临时路径:{file_path}") + provider = ProviderService.get_provider_by_id('groq') + + if not provider: + raise Exception("Groq 供应商未配置,请配置以后使用。") + # build_openai_client 会校验 api_key 非空(空 key 会抛天书般的 + # `Illegal header value b'Bearer '`),并自动注入全局代理 + client = build_openai_client( + api_key=provider.get('api_key'), + base_url=provider.get('base_url'), + key_label="Groq 转写引擎的 API Key", + ) + filename = file_path + + with open(filename, "rb") as file: + transcription = client.audio.transcriptions.create( + file=(filename, file.read()), + model=os.getenv('GROQ_TRANSCRIBER_MODEL'), + response_format="verbose_json", + ) + print(transcription.text) + print(transcription) + segments = [] + full_text = "" + + for seg in transcription.segments: + text = seg.text.strip() + full_text += text + " " + segments.append(TranscriptSegment( + start=seg.start, + end=seg.end, + text=text + )) + + result = TranscriptResult( + language=transcription.language, + full_text=full_text.strip(), + segments=segments, + raw=transcription.to_dict() + ) + return result diff --git a/backend/app/transcriber/kuaishou.py b/backend/app/transcriber/kuaishou.py new file mode 100644 index 0000000000000000000000000000000000000000..e2846b1fa90f0d45349f414684f04c1626ac2ec7 --- /dev/null +++ b/backend/app/transcriber/kuaishou.py @@ -0,0 +1,115 @@ +import requests +import logging +import os +from typing import Union, List, Dict, Optional + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.logger import get_logger +from events import transcription_finished + +logger = get_logger(__name__) + +class KuaishouTranscriber(Transcriber): + """快手语音识别实现""" + + API_URL = "https://ai.kuaishou.com/api/effects/subtitle_generate" + + def __init__(self): + pass + + def _load_file(self, file_path: str) -> bytes: + """读取文件内容""" + with open(file_path, 'rb') as f: + return f.read() + + def _submit(self, file_path: str) -> dict: + """提交识别请求""" + try: + file_binary = self._load_file(file_path) + + payload = { + "typeId": "1" + } + + # 使用文件名作为上传文件名 + file_name = os.path.basename(file_path) + files = [('file', (file_name, file_binary, 'audio/mpeg'))] + + logger.info(f"开始向快手API提交请求,文件: {file_name}") + response = requests.post(self.API_URL, data=payload, files=files, timeout=300) + response.raise_for_status() # 检查HTTP错误 + + result = response.json() + print('result',result) + # 检查快手API返回是否包含错误 + if "data" not in result or result.get("code", 0) != 0: + error_msg = f"快手API返回错误: {result.get('message', '未知错误')}" + logger.error(error_msg) + raise Exception(error_msg) + + return result + + except requests.exceptions.RequestException as e: + error_msg = f"快手ASR请求网络错误: {str(e)}" + logger.error(error_msg) + raise + except Exception as e: + error_msg = f"快手ASR请求处理错误: {str(e)}" + logger.error(error_msg) + raise + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + """执行转录过程,符合 Transcriber 接口""" + try: + logger.info(f"开始处理文件: {file_path}") + + # 提交请求并获取结果 + logger.info("向快手API提交识别请求...") + result_data = self._submit(file_path) + + logger.info("请求成功,处理结果...") + + # 提取分段数据 + segments = [] + full_text = "" + + # 解析快手API返回的文本段 + texts = result_data.get('data', {}).get('text', []) + for u in texts: + text = u.get('text', '').strip() + start_time = float(u.get('start_time', 0)) + end_time = float(u.get('end_time', 0)) + + full_text += text + " " + segments.append(TranscriptSegment( + start=start_time, + end=end_time, + text=text + )) + + # 创建结果对象 + result = TranscriptResult( + language="zh", # 快手API可能不返回语言信息,默认为中文 + full_text=full_text.strip(), + segments=segments, + raw=result_data + ) + + # 触发完成事件 + # self.on_finish(file_path, result) + + return result + + except Exception as e: + logger.error(f"快手ASR处理失败: {str(e)}") + raise + + def on_finish(self, video_path: str, result: TranscriptResult) -> None: + """转录完成的回调""" + logger.info(f"快手ASR转写完成: {video_path}") + transcription_finished.send({ + "file_path": video_path, + }) \ No newline at end of file diff --git a/backend/app/transcriber/mlx_whisper_transcriber.py b/backend/app/transcriber/mlx_whisper_transcriber.py new file mode 100644 index 0000000000000000000000000000000000000000..64fc94629352f4d2eb770c00c6ee51a703db20c8 --- /dev/null +++ b/backend/app/transcriber/mlx_whisper_transcriber.py @@ -0,0 +1,123 @@ +import mlx_whisper +from pathlib import Path +import os +import platform +from huggingface_hub import snapshot_download + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.logger import get_logger +from app.utils.path_helper import get_model_dir +from events import transcription_finished + +logger = get_logger(__name__) + + +# mlx-community 上的 Whisper 仓库命名不统一:常规版本是 'whisper-{size}-mlx', +# turbo 例外没有 -mlx 后缀。直接拼 'mlx-community/whisper-{size}' 会 404。 +# 已用 https://huggingface.co/api/models?author=mlx-community&search=whisper 核对过。 +MLX_MODEL_MAP = { + "tiny": "mlx-community/whisper-tiny-mlx", + "base": "mlx-community/whisper-base-mlx", + "small": "mlx-community/whisper-small-mlx", + "medium": "mlx-community/whisper-medium-mlx", + "large-v1": "mlx-community/whisper-large-v1-mlx", + "large-v2": "mlx-community/whisper-large-v2-mlx", + "large-v3": "mlx-community/whisper-large-v3-mlx", + "large-v3-turbo": "mlx-community/whisper-large-v3-turbo", +} + + +def resolve_mlx_repo_id(model_size: str) -> str: + if model_size not in MLX_MODEL_MAP: + raise ValueError( + f"不支持的 MLX Whisper 模型大小: {model_size}。" + f"可选: {', '.join(MLX_MODEL_MAP.keys())}" + ) + return MLX_MODEL_MAP[model_size] + + +class MLXWhisperTranscriber(Transcriber): + def __init__( + self, + model_size: str = "base" + ): + # 检查平台 + if platform.system() != "Darwin": + raise RuntimeError("MLX Whisper 仅支持 Apple 平台") + + # 注意:不再校验 TRANSCRIBER_TYPE 环境变量——转写引擎早已改为 + # 「音频转写配置」页动态切换(transcriber_config_manager 持久化), + # 桌面端不会设置该环境变量,遗留校验会把 MLX 直接拦死。 + + self.model_size = model_size + self.model_name = resolve_mlx_repo_id(model_size) + self.model_path = None + + # 设置模型路径 + model_dir = get_model_dir("mlx-whisper") + self.model_path = os.path.join(model_dir, self.model_name) + # 用 config.json 而非目录存在作为「下载完成」的判据, + # 同 fast-whisper 的 model.bin:避免半成品目录把后续下载吞掉 + config_file = Path(self.model_path) / "config.json" + if not config_file.exists(): + if Path(self.model_path).exists(): + logger.warning( + f"MLX 模型目录 {self.model_path} 存在但 config.json 缺失(上次下载未完成),重新下载" + ) + else: + logger.info(f"模型 {self.model_name} 不存在,开始下载...") + snapshot_download( + self.model_name, + local_dir=self.model_path, + local_dir_use_symlinks=False, + ) + logger.info("模型下载完成") + + logger.info(f"初始化 MLX Whisper 转录器,模型:{self.model_name}") + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + try: + # 使用 MLX Whisper 进行转录。 + # 必须传 __init__ 里 snapshot_download 落盘的本地目录: + # 传 repo id 会让 mlx_whisper 每次先去 HuggingFace Hub 校验/下载, + # 网络不通时直接 LocalEntryNotFoundError,转写必然失败。 + result = mlx_whisper.transcribe( + file_path, + path_or_hf_repo=self.model_path + ) + + # 转换为标准格式 + segments = [] + full_text = "" + + for segment in result["segments"]: + text = segment["text"].strip() + full_text += text + " " + segments.append(TranscriptSegment( + start=segment["start"], + end=segment["end"], + text=text + )) + + transcript_result = TranscriptResult( + language=result.get("language", "unknown"), + full_text=full_text.strip(), + segments=segments, + raw=result + ) + + # self.on_finish(file_path, transcript_result) + return transcript_result + + except Exception as e: + logger.error(f"MLX Whisper 转写失败:{e}") + raise e + + def on_finish(self, video_path: str, result: TranscriptResult) -> None: + logger.info("MLX Whisper 转写完成") + transcription_finished.send({ + "file_path": video_path, + }) \ No newline at end of file diff --git a/backend/app/transcriber/transcriber_provider.py b/backend/app/transcriber/transcriber_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..d72b0c186d75eb14d5255144fdb1d2a58e6fe28a --- /dev/null +++ b/backend/app/transcriber/transcriber_provider.py @@ -0,0 +1,187 @@ +import importlib.util +import os +import platform +from enum import Enum + +from app.transcriber.groq import GroqTranscriber +from app.transcriber.whisper import WhisperTranscriber +from app.transcriber.bcut import BcutTranscriber +from app.transcriber.kuaishou import KuaishouTranscriber +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +class TranscriberType(str, Enum): + FAST_WHISPER = "fast-whisper" + MLX_WHISPER = "mlx-whisper" + BCUT = "bcut" + KUAISHOU = "kuaishou" + GROQ = "groq" + FUNASR = "funasr" + +# FunASR 可选引擎:用 find_spec 探测是否安装,绝不在此 import(import funasr 会连带加载 +# torch,拖慢启动且桌面瘦身包没有 torch)。真正用到时才在 FunASRTranscriber 内部 import。 +# 桌面冻结包强制不可用:torch 无法在 PyInstaller 冻结运行时初始化(pybind 重复注册崩溃), +# 而且装进插件目录后会被 ctranslate2 的启动链路自动 import,直接把应用打挂。 +import sys as _sys +FUNASR_AVAILABLE = ( + not getattr(_sys, "frozen", False) + and importlib.util.find_spec("funasr") is not None +) +if FUNASR_AVAILABLE: + logger.info("FunASR 可用(已安装 funasr)") + +# 在 Apple 平台尝试导入 MLX Whisper(不再依赖环境变量,支持前端动态切换) +MLX_WHISPER_AVAILABLE = False +if platform.system() == "Darwin": + try: + from app.transcriber.mlx_whisper_transcriber import MLXWhisperTranscriber + MLX_WHISPER_AVAILABLE = True + logger.info("MLX Whisper 可用,已导入") + except ImportError: + logger.warning("MLX Whisper 导入失败,可能未安装 mlx_whisper") + +logger.info('初始化转录服务提供器') + +# 转录器单例缓存 +_transcribers = { + TranscriberType.FAST_WHISPER: None, + TranscriberType.MLX_WHISPER: None, + TranscriberType.BCUT: None, + TranscriberType.KUAISHOU: None, + TranscriberType.GROQ: None, + TranscriberType.FUNASR: None, +} + +# 公共实例初始化函数 +def _init_transcriber(key: TranscriberType, cls, *args, **kwargs): + if _transcribers[key] is None: + logger.info(f'创建 {cls.__name__} 实例: {key}') + try: + _transcribers[key] = cls(*args, **kwargs) + logger.info(f'{cls.__name__} 创建成功') + except Exception as e: + logger.error(f"{cls.__name__} 创建失败: {e}") + raise + return _transcribers[key] + +# 各类型获取方法 +def get_groq_transcriber(): + return _init_transcriber(TranscriberType.GROQ, GroqTranscriber) + +def get_whisper_transcriber(model_size="base", device="cuda"): + # size == "custom":使用用户在「音频转写配置」填的自定义模型(本地目录 / HF 仓库 id) + custom_path = None + if model_size == "custom": + from app.services.transcriber_config_manager import TranscriberConfigManager + custom_path = (TranscriberConfigManager().get_config().get("whisper_custom_model") or "").strip() + if not custom_path: + raise RuntimeError("已选择「自定义」Whisper 模型,但未填写模型路径或仓库 id;请到「音频转写配置」填写。") + + # 实例「变化即重建」:自定义时按路径比较,否则按档位比较 + target_key = custom_path if model_size == "custom" else model_size + inst = _transcribers[TranscriberType.FAST_WHISPER] + if inst is not None and getattr(inst, "model_size", None) != target_key: + logger.info(f"fast-whisper 模型变更 {getattr(inst, 'model_size', None)} -> {target_key},重建实例") + _transcribers[TranscriberType.FAST_WHISPER] = None + + if model_size == "custom": + return _init_transcriber(TranscriberType.FAST_WHISPER, WhisperTranscriber, model_path=custom_path, device=device) + return _init_transcriber(TranscriberType.FAST_WHISPER, WhisperTranscriber, model_size=model_size, device=device) + +def get_bcut_transcriber(): + return _init_transcriber(TranscriberType.BCUT, BcutTranscriber) + +def get_kuaishou_transcriber(): + return _init_transcriber(TranscriberType.KUAISHOU, KuaishouTranscriber) + +def get_funasr_transcriber(model: str = None): + if not FUNASR_AVAILABLE: + raise RuntimeError( + "FunASR 不可用:请先安装依赖(pip install funasr torch torchaudio)," + "安装后重启后端;或在「音频转写配置」页面切换到其他转写引擎。" + ) + # 模型名变更时重建实例(用户可在设置页填自定义 FunASR 模型) + inst = _transcribers[TranscriberType.FUNASR] + if inst is not None and getattr(inst, "model_name", None) != (model or "paraformer-zh"): + logger.info(f"FunASR 模型变更 {getattr(inst, 'model_name', None)} -> {model},重建实例") + _transcribers[TranscriberType.FUNASR] = None + # 延迟 import,避免模块加载阶段触发 torch + from app.transcriber.funasr_transcriber import FunASRTranscriber + return _init_transcriber(TranscriberType.FUNASR, FunASRTranscriber, model=model) + + +def get_mlx_whisper_transcriber(model_size="base"): + if not MLX_WHISPER_AVAILABLE: + logger.warning("MLX Whisper 不可用,请确保在 Apple 平台且已安装 mlx_whisper") + raise ImportError("MLX Whisper 不可用") + # 模型大小变更时重建实例:单例只按类型缓存,否则设置页切换 size 不生效 + inst = _transcribers[TranscriberType.MLX_WHISPER] + if inst is not None and getattr(inst, "model_size", None) != model_size: + logger.info(f"mlx-whisper 模型大小变更 {getattr(inst, 'model_size', None)} -> {model_size},重建实例") + _transcribers[TranscriberType.MLX_WHISPER] = None + return _init_transcriber(TranscriberType.MLX_WHISPER, MLXWhisperTranscriber, model_size=model_size) + +# 通用入口 +def get_transcriber(transcriber_type="fast-whisper", model_size=None, device="cuda"): + """ + 获取指定类型的转录器实例 + + 参数: + transcriber_type: 支持 "fast-whisper", "mlx-whisper", "bcut", "kuaishou", "groq" + model_size: 模型大小,适用于 whisper 类;不传时回退到环境变量 WHISPER_MODEL_SIZE + device: 设备类型(如 cuda / cpu),仅 whisper 使用 + + 返回: + 对应类型的转录器实例 + """ + logger.info(f'请求转录器类型: {transcriber_type}, 模型大小: {model_size or "(默认)"}') + + try: + transcriber_enum = TranscriberType(transcriber_type) + except ValueError: + logger.warning(f'未知转录器类型 "{transcriber_type}",默认使用 fast-whisper') + transcriber_enum = TranscriberType.FAST_WHISPER + + # 显式入参优先(来自「音频转写配置」页持久化的配置),环境变量只做未传参时的默认值。 + # 旧逻辑是环境变量覆盖入参,导致设置页选的模型大小永远被 .env 里的值顶掉。 + whisper_model_size = model_size or os.environ.get("WHISPER_MODEL_SIZE", "base") + + if transcriber_enum == TranscriberType.FAST_WHISPER: + return get_whisper_transcriber(whisper_model_size, device=device) + + elif transcriber_enum == TranscriberType.MLX_WHISPER: + if not MLX_WHISPER_AVAILABLE: + import sys + if getattr(sys, "frozen", False): + from app.utils.path_helper import get_plugin_packages_dir + hint = ( + f'请在终端执行:python3.11 -m pip install --target ' + f'"{get_plugin_packages_dir()}" mlx_whisper(需要 Python 3.11),' + "安装后重启应用生效;" + ) + else: + hint = "请安装 mlx_whisper 包(pip install mlx_whisper)后重启后端;" + raise RuntimeError( + f"MLX Whisper 不可用:需要 macOS(Apple Silicon)平台。{hint}" + "或在「音频转写配置」页面切换到其他转写引擎。" + ) + return get_mlx_whisper_transcriber(whisper_model_size) + + elif transcriber_enum == TranscriberType.BCUT: + return get_bcut_transcriber() + + elif transcriber_enum == TranscriberType.KUAISHOU: + return get_kuaishou_transcriber() + + elif transcriber_enum == TranscriberType.GROQ: + return get_groq_transcriber() + + elif transcriber_enum == TranscriberType.FUNASR: + from app.services.transcriber_config_manager import TranscriberConfigManager + funasr_model = TranscriberConfigManager().get_config().get("funasr_model") or "paraformer-zh" + return get_funasr_transcriber(model=funasr_model) + + # fallback + logger.warning(f'未识别转录器类型 "{transcriber_type}",使用 fast-whisper 作为默认') + return get_whisper_transcriber(whisper_model_size, device=device) diff --git a/backend/app/transcriber/whisper.py b/backend/app/transcriber/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bd8ac1bb1dc7bef91f5e4ca1c55176772d69a4 --- /dev/null +++ b/backend/app/transcriber/whisper.py @@ -0,0 +1,150 @@ +from faster_whisper import WhisperModel + +from app.decorators.timeit import timeit +from app.models.transcriber_model import TranscriptSegment, TranscriptResult +from app.transcriber.base import Transcriber +from app.utils.env_checker import is_cuda_available, is_torch_installed +from app.utils.logger import get_logger +from app.utils.path_helper import get_model_dir + +from events import transcription_finished +from pathlib import Path +import os +import shutil + + +''' + Size of the model to use (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2, distil-large-v3, large-v3-turbo, or turbo +''' +logger=get_logger(__name__) + +# 历史遗留:之前用 modelscope 下载到自定义目录然后把路径传给 WhisperModel。 +# 但 faster-whisper 1.1.1 的 download_model(utils.py:76)逻辑是: +# 只要 size_or_id 里含 "/" 就当 HF repo_id 处理,没有「本地目录直接返回」分支。 +# 我们传 /app/models/whisper/whisper-tiny 进去 → 被当成不存在的 HF repo → +# 在线请求失败 → fallback local_files_only=True → HF cache 找不到(因为是 +# modelscope 目录布局不是 HF)→ LocalEntryNotFoundError,误导说"离线模式"。 +# 解法:彻底让 faster-whisper 自己处理下载——传 size name,配 download_root +# 作为 HF cache 根目录,HF_ENDPOINT 已经在 Dockerfile 里指到 hf-mirror.com, +# 国内能用。删掉 modelscope 那一套,避免布局不匹配。 +class WhisperTranscriber(Transcriber): + def __init__( + self, + model_size: str = "base", + device: str = 'cpu', + compute_type: str = None, + cpu_threads: int = 1, + model_path: str = None, + ): + if device == 'cpu' or device is None: + self.device = 'cpu' + else: + self.device = "cuda" if self.is_cuda() else "cpu" + if device == 'cuda' and self.device == 'cpu': + print('没有 cuda 使用 cpu进行计算') + + self.compute_type = compute_type or ("float16" if self.device == "cuda" else "int8") + + # model_path 非空:用户自定义模型(本地 CTranslate2 目录 或 HF 仓库 id)。 + # faster-whisper 的 WhisperModel 对存在的本地目录直接加载;含 "/" 的字符串当 HF repo。 + self.is_custom = bool(model_path) + # 单例「变化即重建」按 self.model_size 比较,自定义时用路径本身作为键 + self.model_size = model_path if self.is_custom else model_size + self._source = model_path if self.is_custom else model_size + + model_dir = get_model_dir("whisper") + try: + self.model = self._build_model(self._source, model_dir) + except Exception as e: + if self.is_custom: + # 自定义模型不动 cache(命名规则不适用),直接抛出可读错误 + logger.error(f"加载自定义 whisper 模型失败({self._source}):{e}") + raise + # 自愈:损坏 / 截断 / 半成品 cache → 删掉对应 HF cache 重下一次 + logger.warning(f"加载 whisper-{model_size} 失败:{e};清理 cache 后重新下载") + self._purge_cache(model_dir, model_size) + self.model = self._build_model(self._source, model_dir) + + def _build_model(self, source: str, model_dir: str) -> WhisperModel: + return WhisperModel( + model_size_or_path=source, # 预设档名 / 自定义本地目录 / HF 仓库 id + device=self.device, + compute_type=self.compute_type, + download_root=model_dir, + ) + + @staticmethod + def _purge_cache(model_dir: str, model_size: str) -> None: + """删掉 HF cache 里这个 size 对应的 snapshot 目录,强制下次重新下载。 + + HF cache 布局:/models--Systran--faster-whisper-{size}/ + 没找到也不报错——可能用户改了 endpoint 或者 cache 布局变了。 + """ + candidates = [ + Path(model_dir) / f"models--Systran--faster-whisper-{model_size}", + Path(model_dir) / f"whisper-{model_size}", # 历史 modelscope 目录,顺手清掉 + ] + for path in candidates: + if path.exists(): + logger.info(f"清理损坏 cache: {path}") + shutil.rmtree(path, ignore_errors=True) + @staticmethod + def is_torch_installed() -> bool: + try: + import torch + return True + except ImportError: + return False + + @staticmethod + def is_cuda() -> bool: + try: + if is_cuda_available(): + print(" CUDA 可用,使用 GPU") + return True + elif is_torch_installed(): + print(" 只装了 torch,但没有 CUDA,用 CPU") + return False + else: + print(" 还没有安装 torch,请先安装") + return False + + except ImportError: + return False + + @timeit + def transcript(self, file_path: str) -> TranscriptResult: + try: + + segments_raw, info = self.model.transcribe(file_path) + + segments = [] + full_text = "" + + for seg in segments_raw: + text = seg.text.strip() + full_text += text + " " + segments.append(TranscriptSegment( + start=seg.start, + end=seg.end, + text=text + )) + + result= TranscriptResult( + language=info.language, + full_text=full_text.strip(), + segments=segments, + raw=info + ) + # self.on_finish(file_path, result) + return result + except Exception as e: + print(f"转写失败:{e}") + + + def on_finish(self,video_path:str,result: TranscriptResult)->None: + print("转写完成") + transcription_finished.send({ + "file_path": video_path, + }) + diff --git a/backend/app/utils/cover_helper.py b/backend/app/utils/cover_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..83954971079fab4acfda6e8c49fdf39e61e52ea0 --- /dev/null +++ b/backend/app/utils/cover_helper.py @@ -0,0 +1,106 @@ +"""封面图本地化工具。 + +为什么需要: +- B 站封面是 ``http://`` 直链,桌面端(Tauri WebView 的安全上下文)会按 + mixed content 直接拦截,左侧列表/阅读区 banner 只剩渐变兜底; +- 抖音 / 快手封面是带签名的限时 CDN URL,过期后 404,代理也救不回来。 + +所以在笔记生成阶段就把封面下载到 ``/static/covers/`` 本地缓存, +结果 JSON 里存稳定的相对路径,前端到哪个端(web / 桌面 / 扩展)都能渲染。 +""" +import hashlib +import os +from urllib.parse import urlparse + +import requests + +from app.utils.path_helper import get_runtime_dir +from app.utils.logger import get_logger + +logger = get_logger(__name__) + +DEFAULT_UA = ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 " + "(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" +) + +# Content-Type → 扩展名;遇到不认识的类型统一按 jpg 存 +_EXT_BY_TYPE = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/webp": ".webp", + "image/gif": ".gif", + "image/avif": ".avif", +} + + +def pick_referer(image_url: str) -> str: + """根据图片 URL 的 host 选择合适的 Referer。 + + 各平台(B 站 / YouTube / 抖音 / 小红书)的 CDN 都做了 Referer 校验: + 用错了 Referer 会被 403。 + """ + try: + host = (urlparse(image_url).hostname or "").lower() + except Exception: + return "" + + if any(s in host for s in ("bilibili", "hdslb", "biliimg")): + return "https://www.bilibili.com/" + if any(s in host for s in ("youtube", "ytimg", "ggpht", "googlevideo")): + return "https://www.youtube.com/" + if any(s in host for s in ("xiaohongshu", "xhscdn", "xhslink")): + return "https://www.xiaohongshu.com/" + if any(s in host for s in ("douyin", "douyinpic", "douyinvod", "iesdouyin", "amemv")): + return "https://www.douyin.com/" + if "kuaishou" in host or "yximgs" in host: + return "https://www.kuaishou.com/" + # 其它平台不发 Referer,让服务器决定。 + return "" + + +def _covers_dir() -> str: + path = os.path.join(get_runtime_dir("static"), "covers") + os.makedirs(path, exist_ok=True) + return path + + +def localize_cover(cover_url: str, platform: str = "") -> str | None: + """把远程封面下载到 ``static/covers/``,返回 ``/static/covers/xxx`` 相对路径。 + + 幂等:同一 URL 落到同一文件名(URL md5),已存在直接复用。 + 任何失败都返回 None,调用方保留原始 URL,不影响笔记生成主流程。 + """ + if not cover_url or not str(cover_url).startswith(("http://", "https://")): + return None + + digest = hashlib.md5(cover_url.encode("utf-8")).hexdigest()[:20] + covers_dir = _covers_dir() + + # 已缓存过(任意扩展名)直接复用 + for ext in set(_EXT_BY_TYPE.values()): + cached = os.path.join(covers_dir, f"{digest}{ext}") + if os.path.exists(cached): + return f"/static/covers/{digest}{ext}" + + headers = {"User-Agent": DEFAULT_UA} + referer = pick_referer(cover_url) + if referer: + headers["Referer"] = referer + + try: + resp = requests.get(cover_url, headers=headers, timeout=10) + resp.raise_for_status() + content_type = (resp.headers.get("Content-Type") or "").split(";")[0].strip().lower() + if content_type and not content_type.startswith("image/"): + logger.warning(f"封面本地化跳过:返回的不是图片 Content-Type={content_type} url={cover_url}") + return None + ext = _EXT_BY_TYPE.get(content_type, ".jpg") + filename = f"{digest}{ext}" + with open(os.path.join(covers_dir, filename), "wb") as f: + f.write(resp.content) + logger.info(f"封面已本地化: {cover_url} -> /static/covers/{filename}") + return f"/static/covers/{filename}" + except Exception as e: + logger.warning(f"封面本地化失败(保留原始 URL): {e} url={cover_url}") + return None diff --git a/backend/app/utils/env_checker.py b/backend/app/utils/env_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..667cfd40b4d6343f38cfdf2f98bd599e3ca66a1f --- /dev/null +++ b/backend/app/utils/env_checker.py @@ -0,0 +1,12 @@ +def is_cuda_available() -> bool: + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False +def is_torch_installed() -> bool: + try: + import torch + return True + except ImportError: + return False diff --git a/backend/app/utils/export.py b/backend/app/utils/export.py new file mode 100644 index 0000000000000000000000000000000000000000..aa22c62027aa9b4915d735baeaa1e0ef513654df --- /dev/null +++ b/backend/app/utils/export.py @@ -0,0 +1,450 @@ +import os +import re +from urllib.parse import quote +from markdown_pdf import MarkdownPdf, Section +from dotenv import load_dotenv + +load_dotenv() + +# 项目根路径(无论你在哪里运行) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# 从 .env 获取 DATA_DIR,相对于 BASE_DIR 解析 +DATA_DIR_NAME = os.getenv("DATA_DIR", "data") +DATA_DIR = os.path.join(BASE_DIR, DATA_DIR_NAME) +SAVE_PATH = os.path.join(DATA_DIR, "note_output") +IMAGE_BASE_URL = os.getenv("IMAGE_BASE_URL") +STATIC_BASE = os.path.join(BASE_DIR, IMAGE_BASE_URL) + + +class ExportUtils: + def __init__(self, **kwargs): + # 确认SAVE_PATH存在 + print(f"保存路径: {SAVE_PATH}") + print(f"静态文件路径: {STATIC_BASE}") + if not os.path.exists(SAVE_PATH): + os.makedirs(SAVE_PATH) + + def _embed_image_as_base64(self, img_path: str) -> str: + """ + 将图片转换为 base64 格式嵌入 + """ + import base64 + import mimetypes + + try: + # 获取 MIME 类型 + mime_type, _ = mimetypes.guess_type(img_path) + if not mime_type: + # 根据扩展名推断 + ext = os.path.splitext(img_path)[1].lower() + mime_map = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.webp': 'image/webp', + '.svg': 'image/svg+xml' + } + mime_type = mime_map.get(ext, 'image/png') + + # 读取图片文件并转换为 base64 + with open(img_path, 'rb') as f: + img_data = f.read() + + base64_data = base64.b64encode(img_data).decode('utf-8') + return f"data:{mime_type};base64,{base64_data}" + + except Exception as e: + print(f"图片 base64 编码失败 {img_path}: {str(e)}") + return None + + def _get_normalized_path(self, path: str) -> str: + """ + 获取规范化的绝对路径 + """ + return os.path.normpath(os.path.abspath(path)) + + def _replace_static_paths_with_absolute(self, content: str) -> str: + """ + 将 Markdown 中的图片路径替换为 base64 内嵌格式 + 这样可以确保图片在 PDF 中正确显示 + """ + + def repl(match): + # 捕获 alt 文本和路径 + alt_text = match.group(1) if match.group(1) else "" + img_path = match.group(2).strip() + + print(f"处理图片路径: {img_path}") + + # 处理 /static/ 开头的路径 + if img_path.startswith("/static/"): + # 构建绝对路径 + relative_path = img_path.lstrip("/") # 移除开头的 / + abs_path = os.path.join(BASE_DIR, relative_path) + abs_path = self._get_normalized_path(abs_path) + + # 检查文件是否存在并转换为 base64 + if os.path.exists(abs_path): + base64_uri = self._embed_image_as_base64(abs_path) + if base64_uri: + print(f"图片转换为 base64 成功: {img_path}") + return f"![{alt_text}]({base64_uri})" + else: + print(f"图片 base64 转换失败: {abs_path}") + return f"![{alt_text}](图片转换失败: {img_path})" + else: + print(f"警告:图片文件不存在 {abs_path}") + return f"![{alt_text}](图片不存在: {img_path})" + + # 处理相对路径(相对于 STATIC_BASE) + elif not img_path.startswith(('http://', 'https://', 'data:')): + # 尝试多个可能的路径 + possible_paths = [ + os.path.join(STATIC_BASE, img_path), + os.path.abspath(img_path), + os.path.join(BASE_DIR, img_path) + ] + + for abs_path in possible_paths: + abs_path = self._get_normalized_path(abs_path) + if os.path.exists(abs_path): + base64_uri = self._embed_image_as_base64(abs_path) + if base64_uri: + print(f"相对路径图片转换为 base64 成功: {img_path}") + return f"![{alt_text}]({base64_uri})" + break + + print(f"警告:图片文件未找到 {img_path}") + return f"![{alt_text}](图片未找到: {img_path})" + + # HTTP/HTTPS 和 data: 路径保持不变 + elif img_path.startswith(('http://', 'https://', 'data:')): + print(f"网络图片或 data URI 保持不变: {img_path[:50]}...") + return match.group(0) + + # 其他情况保持不变 + return match.group(0) + + # 使用更精确的正则表达式匹配图片语法 + # 匹配 ![alt text](path) 格式 + pattern = r'!\[([^\]]*)\]\(([^)]+)\)' + result = re.sub(pattern, repl, content) + + print("图片路径处理完成") + return result + + def _to_html(self, content: str, title: str): + """ + 将 Markdown 内容转换为 HTML。 + 用 markdown-it-py 做转换,套一个简洁中文友好的 CSS。 + """ + try: + from markdown_it import MarkdownIt + + md = MarkdownIt("commonmark", {"html": True, "linkify": True, "typographer": True}).enable("table").enable("strikethrough") + html_body = md.render(content) + + css = """ + body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "PingFang SC", + "Hiragino Sans GB", "Microsoft YaHei", "Source Han Sans CN", + Roboto, Helvetica, Arial, sans-serif; + line-height: 1.7; + color: #24292f; + max-width: 860px; + margin: 32px auto; + padding: 0 24px 64px; + } + h1, h2, h3, h4 { line-height: 1.3; margin-top: 1.4em; margin-bottom: 0.6em; font-weight: 600; } + h1 { font-size: 2em; border-bottom: 1px solid #eaecef; padding-bottom: 0.3em; } + h2 { font-size: 1.5em; border-bottom: 1px solid #eaecef; padding-bottom: 0.3em; } + h3 { font-size: 1.25em; } + p { margin: 0.8em 0; } + a { color: #0969da; text-decoration: none; } + a:hover { text-decoration: underline; } + code { background: rgba(175,184,193,0.2); padding: 0.18em 0.4em; border-radius: 4px; + font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace; font-size: 0.92em; } + pre { background: #f6f8fa; padding: 14px 16px; border-radius: 6px; overflow-x: auto; } + pre code { background: none; padding: 0; } + blockquote { border-left: 4px solid #d0d7de; color: #57606a; padding: 0 1em; margin: 1em 0; } + img { max-width: 100%; border-radius: 6px; } + table { border-collapse: collapse; margin: 1em 0; } + table th, table td { border: 1px solid #d0d7de; padding: 6px 13px; } + table th { background: #f6f8fa; font-weight: 600; } + ul, ol { padding-left: 1.6em; } + hr { border: 0; border-top: 1px solid #eaecef; margin: 2em 0; } + """ + + html_doc = f""" + + + + +{title} + + + +

{title}

+{html_body} + + +""" + save_path = os.path.join(SAVE_PATH, f"{title}.html") + with open(save_path, "w", encoding="utf-8") as f: + f.write(html_doc) + print(f"HTML 导出成功: {save_path}") + return save_path + except Exception as e: + print(f"HTML 导出失败: {e}") + raise + + def _to_word(self, content: str, title: str): + """ + 将 Markdown 内容转换为 Word (.docx)。 + 策略:markdown → HTML(已有 _to_html 的 md),然后用 BeautifulSoup 遍历元素塞进 python-docx。 + 覆盖标题(h1-h4)、段落、有/无序列表、代码块、链接(保留文本)、图片(base64 已被前置处理, + 先解出来再插入),其它当作普通段落处理。 + """ + try: + from io import BytesIO + import base64 as _b64 + from markdown_it import MarkdownIt + from bs4 import BeautifulSoup + from docx import Document + from docx.shared import Pt, Inches + + md = MarkdownIt("commonmark", {"html": True}).enable("table").enable("strikethrough") + html = md.render(content) + soup = BeautifulSoup(html, "html.parser") + + doc = Document() + # 默认正文字体 + style = doc.styles["Normal"] + style.font.name = "Microsoft YaHei" + style.font.size = Pt(11) + + doc.add_heading(title, level=0) + + heading_levels = {"h1": 1, "h2": 2, "h3": 3, "h4": 4, "h5": 5, "h6": 6} + + def _add_image(src: str): + try: + if src.startswith("data:"): + # data:image/png;base64,XXXX + header, _, b64 = src.partition(",") + img_bytes = _b64.b64decode(b64) + doc.add_picture(BytesIO(img_bytes), width=Inches(5.5)) + elif src.startswith(("http://", "https://")): + doc.add_paragraph(f"[网络图片] {src}") + elif os.path.exists(src): + doc.add_picture(src, width=Inches(5.5)) + else: + doc.add_paragraph(f"[图片缺失] {src}") + except Exception as e: + doc.add_paragraph(f"[图片插入失败] {src}: {e}") + + def _add_paragraph_with_inline(el): + """段落级元素:保留文本,链接/strong/em 都按纯文本处理(python-docx 内联格式比较繁,先够用)。""" + text = el.get_text("", strip=False).strip() + if text: + doc.add_paragraph(text) + # 段落内嵌的图片单独提出来 + for img in el.find_all("img"): + _add_image(img.get("src", "")) + + for child in soup.children: + name = getattr(child, "name", None) + if name is None: + continue + if name in heading_levels: + doc.add_heading(child.get_text(strip=True), level=heading_levels[name]) + elif name == "p": + _add_paragraph_with_inline(child) + elif name in ("ul", "ol"): + style_name = "List Bullet" if name == "ul" else "List Number" + for li in child.find_all("li", recursive=False): + doc.add_paragraph(li.get_text(strip=True), style=style_name) + elif name == "pre": + code_text = child.get_text("", strip=False) + p = doc.add_paragraph() + run = p.add_run(code_text) + run.font.name = "Consolas" + run.font.size = Pt(10) + elif name == "blockquote": + p = doc.add_paragraph(child.get_text(strip=True)) + p.paragraph_format.left_indent = Inches(0.3) + elif name == "img": + _add_image(child.get("src", "")) + elif name == "hr": + doc.add_paragraph("———————————————") + elif name == "table": + rows = child.find_all("tr") + if not rows: + continue + cols = max(len(r.find_all(["td", "th"])) for r in rows) + table = doc.add_table(rows=len(rows), cols=cols) + table.style = "Light Grid Accent 1" + for r_idx, tr in enumerate(rows): + cells = tr.find_all(["td", "th"]) + for c_idx, td in enumerate(cells): + table.cell(r_idx, c_idx).text = td.get_text(strip=True) + else: + # 兜底:当成段落 + text = child.get_text("", strip=False).strip() + if text: + doc.add_paragraph(text) + + save_path = os.path.join(SAVE_PATH, f"{title}.docx") + doc.save(save_path) + print(f"Word 导出成功: {save_path}") + return save_path + except Exception as e: + print(f"Word 导出失败: {e}") + raise + + def _to_pdf(self, content: str, title: str): + """ + 将 Markdown 内容转换为 PDF + """ + try: + # 创建 PDF 对象,启用优化 + pdf = MarkdownPdf( + optimize=True, + # 添加一些可能有助于图片显示的配置 + # toc=False, + # paper_size='A4', + # margin=dict(top='1cm', bottom='1cm', left='1cm', right='1cm') + ) + + # 添加内容段落 + pdf.add_section(Section(content)) + + # 保存 PDF + save_path = os.path.join(SAVE_PATH, f"{title}.pdf") + pdf.save(save_path) + + print(f"PDF 导出成功: {save_path}") + return save_path + + except Exception as e: + print(f"PDF 导出失败: {str(e)}") + print("尝试使用基本配置...") + try: + # 尝试最基本的配置 + pdf = MarkdownPdf() + pdf.add_section(Section(content)) + save_path = os.path.join(SAVE_PATH, f"{title}.pdf") + pdf.save(save_path) + print(f"基本配置 PDF 导出成功: {save_path}") + return save_path + except Exception as e2: + print(f"基本配置也失败: {str(e2)}") + raise e2 + + def export(self, output_format: str, title: str, content: str) -> str: + """ + 导出内容为指定格式 + 支持格式:pdf, html, word/docx, image/png + """ + content = content.strip() + + # 处理图片路径 + print("开始处理图片路径...") + content = self._replace_static_paths_with_absolute(content) + + output_format = output_format.lower() + + try: + if output_format == "pdf": + save_path = self._to_pdf(content, title) + elif output_format == "html": + save_path = self._to_html(content, title) + elif output_format in ["word", "docx"]: + save_path = self._to_word(content, title) + else: + supported_formats = ["pdf", "html", "word/docx"] + raise ValueError(f"不支持的导出格式: {output_format}. 支持的格式: {', '.join(supported_formats)}") + + print(f"导出完成: {save_path}") + return save_path + + except Exception as e: + print(f"导出失败: {str(e)}") + raise e + + def get_supported_formats(self): + """ + 返回支持的导出格式列表 + """ + return { + "pdf": "PDF 文档", + "html": "HTML 网页", + "word": "Word 文档 (.docx)", + "docx": "Word 文档 (.docx)", + } + def debug_paths(self): + """ + 调试方法:打印重要路径信息 + """ + print("=== 路径调试信息 ===") + print(f"BASE_DIR: {BASE_DIR}") + print(f"DATA_DIR: {DATA_DIR}") + print(f"SAVE_PATH: {SAVE_PATH}") + print(f"STATIC_BASE: {STATIC_BASE}") + print(f"IMAGE_BASE_URL: {IMAGE_BASE_URL}") + print("==================") + +if __name__ == '__main__': + + ExportUtils().export("pdf",title='测试',content='''# 视频笔记:Facial Recognition Forces My Coworkers to Do Their Dishes + +## 简介 +该视频展示了团队如何利用面部识别技术来监控和激励同事清洗餐具。通过结合硬件和软件,团队开发了一个“Dish Watcher”系统,旨在识别并提醒那些未清洁餐具的人。 + +## 背景 +- 团队面临的问题是同事们不愿意清洗餐具。 +- 为解决这一问题,团队决定在不告知的情况下使用技术来监控厨房区域。 + +## 实验设计 +1\. **设备安装** +- 使用Raspberry Pi和隐藏摄像头来捕捉厨房水槽的活动。 +- 摄像头只在有人在水槽附近活动时录制,以节省存储空间。 + +2\. **软件开发** +- 使用Cursor AI和Meta的项目来分析视频。 +- 系统能识别人员特征如发型、服装,并将结果发送到Discord服务器以提醒团队。 + +3\. **面部识别** +- 通过视频流实时分析来判断是否有人留下了脏餐具。 +- 系统能识别并记录下未清洗餐具的人的详细特征。 + +![](/static/screenshots/screenshot_000_a61be29d-06ae-42ee-ac38-2d0b1db394f3.jpg)* 展示了堆积的脏餐具,问题的严重性可见一斑。 + +## 实验过程 +- 系统成功捕获了少数“罪犯”,并通过Discord进行了通知。 +- 计划将摄像头隐藏在厨房的画作后,使其更加隐蔽。 + +![](/static/screenshots/screenshot_001_e9d1c7ad-509e-4c7d-a718-a09193e97724.jpg)* SAM 介绍了项目的背景。 + +## 结果 +- 实验初期,系统有效地识别了不清洗餐具的同事。 +- 由于摄像头的存在,同事们开始自觉清洗餐具,长时间未发现新的“罪犯”。 + +## 思考与改进 +- 团队意识到仅仅通过惩罚来改变行为可能效果有限,考虑奖励来激励清洗餐具。 +- 系统将改进为奖励机制,记录并表扬那些清洗餐具的人。 + +## 总结 +这次实验展示了技术在工作场所行为管理中的应用潜力。通过实验,团队不仅解决了餐具清洗的问题,还对如何更有效地激励员工有了更深的认识。 + +![](/static/screenshots/screenshot_002_f1ca0c20-c657-417f-be78-7958bf0e7a4b.jpg)* 展示了系统对某位同事洗碗的实时面部识别。 + +## 结论 +- 应用技术可以有效改善工作环境中的小问题。 +- 积极的激励比惩罚更能驱动行为改变。 + +通过这次实验,团队不仅解决了餐具堆积的问题,还为未来更复杂的行为管理系统奠定了基础。 ''',) + diff --git a/backend/app/utils/logger.py b/backend/app/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..72b7e6e9f7f8078180d51457550f061954f059de --- /dev/null +++ b/backend/app/utils/logger.py @@ -0,0 +1,32 @@ +import logging +import sys +from pathlib import Path + +# 日志目录 +LOG_DIR = Path("logs") +LOG_DIR.mkdir(exist_ok=True) + +# 日志格式 +formatter = logging.Formatter( + fmt="%(asctime)s [%(levelname)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S" +) + +# 控制台输出 +console_handler = logging.StreamHandler(sys.stdout) +console_handler.setFormatter(formatter) + +# 文件输出 +file_handler = logging.FileHandler(LOG_DIR / "app.log", encoding="utf-8") +file_handler.setFormatter(formatter) + +# 获取日志器 + +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + if not logger.handlers: + logger.setLevel(logging.INFO) + logger.addHandler(console_handler) + logger.addHandler(file_handler) + logger.propagate = False + return logger diff --git a/backend/app/utils/note_helper.py b/backend/app/utils/note_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..228b9cf877e072a1b9fa6c79bff31a574b77e90c --- /dev/null +++ b/backend/app/utils/note_helper.py @@ -0,0 +1,118 @@ +import re + + +def prepend_source_link(markdown: str | None, source_url: str) -> str | None: + """ + 在笔记开头添加来源链接;若首个非空行已包含来源链接,则更新该行并避免重复。 + """ + if markdown is None: + return None + + source = (source_url or "").strip() + if not source: + return markdown + + header = f"> 来源链接:{source}" + lines = markdown.splitlines() + first_non_empty_idx = None + for idx, line in enumerate(lines): + if line.strip(): + first_non_empty_idx = idx + break + + if first_non_empty_idx is not None: + first_line = lines[first_non_empty_idx].strip() + if first_line.startswith("> 来源链接:") or first_line.startswith("来源链接:"): + lines[first_non_empty_idx] = header + return "\n".join(lines) + + if markdown.strip(): + return f"{header}\n\n{markdown}" + return header + + +def normalize_toc(markdown: str | None) -> str | None: + """规范化「## 目录」区块:剥掉目录条目里误带的 `#`/`##` 标题标记。 + + LLM 有时把章节标题的 `##` 标记原样抄进目录列表(`- ## 1. xxx`), + 渲染出来和正文标题一样大。这里只做一件事:把目录区块内所有列表条目 + (含缩进子项)开头的标题标记剥掉——嵌套子项、加粗、链接等都允许, + 原样保留。没有目录区块时原样返回。 + """ + if not markdown: + return markdown + + lines = markdown.split('\n') + out = [] + in_toc = False + for line in lines: + stripped = line.strip() + # 目录区块开始(容忍 #/##/### 任意级别写法,统一归一为 ##) + if re.match(r'^#{1,6}\s*目录\s*$', stripped): + in_toc = True + out.append('## 目录') + continue + if in_toc: + # 下一个标题出现,目录区块结束 + if re.match(r'^#{1,6}\s', stripped): + in_toc = False + out.append(line) + continue + m = re.match(r'^(\s*[-*+]\s+)(.*)$', line) + if m: + prefix, item = m.group(1), m.group(2) + # 只剥条目开头的标题标记;兼容加粗包裹的写法(**## xxx** → **xxx**), + # 缩进/加粗/其余内容全部原样保留 + item = re.sub(r'^(\*{0,2})\s*#{1,6}\s+', r'\1', item) + out.append(prefix + item) + continue + # 目录区块内的空行 / 其他杂行原样保留 + out.append(line) + continue + out.append(line) + return '\n'.join(out) + + +def build_timestamp_url(platform: str, video_id: str, total_seconds: int) -> str | None: + """按平台拼接「跳转到第 total_seconds 秒」的视频链接。 + + 仅 B 站 / YouTube 支持可靠的时间戳跳转;抖音 / 快手 / 小红书等只能给出 + 视频本身的链接(无时间参数);无法识别的平台返回 None(调用方降级为纯文本)。 + """ + if platform == 'bilibili': + # video_id 形如 BV1xxx 或 BV1xxx_p2(多 P);_p 段转成查询参数 + if "_p" in video_id: + bvid, _, page = video_id.partition("_p") + return f"https://www.bilibili.com/video/{bvid}?p={page}&t={total_seconds}" + return f"https://www.bilibili.com/video/{video_id}?t={total_seconds}" + if platform == 'youtube': + return f"https://www.youtube.com/watch?v={video_id}&t={total_seconds}s" + if platform == 'douyin': + return f"https://www.douyin.com/video/{video_id}" + if platform == 'kuaishou': + return f"https://www.kuaishou.com/short-video/{video_id}" + if platform == 'xiaohongshu': + return f"https://www.xiaohongshu.com/explore/{video_id}" + return None + + +def replace_content_markers(markdown: str, video_id: str, platform: str = 'bilibili') -> str: + """ + 替换 *Content-04:16*、Content-04:16 或 Content-[04:16] 为超链接,跳转到对应平台视频的时间位置 + """ + # 匹配三种形式:*Content-04:16*、Content-04:16、Content-[04:16] + pattern = r"(?:\*?)Content-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2}))" + + def replacer(match): + mm = match.group(1) or match.group(3) + ss = match.group(2) or match.group(4) + total_seconds = int(mm) * 60 + int(ss) + + url = build_timestamp_url(platform, video_id, total_seconds) + if not url: + # 平台无法拼出链接:降级为纯文本时间,不留下死链 + return f"({mm}:{ss})" + return f"[原片 @ {mm}:{ss}]({url})" + + return re.sub(pattern, replacer, markdown) + diff --git a/backend/app/utils/openai_client.py b/backend/app/utils/openai_client.py new file mode 100644 index 0000000000000000000000000000000000000000..35c9a6a76997abaabd6afe5ae44ef4b1a5d12b32 --- /dev/null +++ b/backend/app/utils/openai_client.py @@ -0,0 +1,55 @@ +"""统一构造 OpenAI 兼容客户端:注入全局代理 + 校验 api_key。 + +为什么要这一层: + - 代理:openai SDK 默认只认进程级 HTTP_PROXY 环境变量,桌面端用户在 UI 里 + 填的代理需要显式塞进 httpx.Client 才生效。 + - api_key 校验:空 key 会让 httpx 拼出非法 header `Bearer `,抛出 + `httpx.LocalProtocolError: Illegal header value b'Bearer '` 这种天书报错。 + 在入口挡掉,给用户「xxx 的 API Key 未配置」这种能看懂的提示。 +""" +from typing import Optional + +from openai import OpenAI + +from app.services.proxy_config_manager import ProxyConfigManager +from app.utils.logger import get_logger + +logger = get_logger(__name__) + + +def build_openai_client( + api_key: Optional[str], + base_url: Optional[str], + *, + key_label: str = "API Key", + timeout: Optional[float] = None, +) -> OpenAI: + """构造 OpenAI 客户端。api_key 为空直接抛清晰错误;代理已配置则注入。 + + key_label 用于错误提示,例如 "Groq 的 API Key" / "OpenAI 供应商的 API Key"。 + """ + if not api_key or not str(api_key).strip(): + raise ValueError(f"{key_label} 未配置,请先在「设置」里填写后再使用") + + kwargs = {"api_key": str(api_key).strip(), "base_url": base_url} + if timeout is not None: + kwargs["timeout"] = timeout + + # 始终显式传入 httpx.Client(trust_env=False): + # 本机环境里常见的 NO_PROXY=::1 会触发 httpx 解析异常 + # `Invalid port: ':1'`,导致 OpenAI 客户端还没发请求就失败。 + # 应用代理仍由 ProxyConfigManager 统一读取并显式注入。 + import httpx + + http_client_kwargs = { + "timeout": timeout or 600.0, + "trust_env": False, + } + proxy_url = ProxyConfigManager().get_proxy_url() + if proxy_url: + http_client_kwargs["proxy"] = proxy_url + logger.info(f"OpenAI 客户端走代理: {proxy_url}") + + kwargs["http_client"] = httpx.Client(**http_client_kwargs) + + return OpenAI(**kwargs) diff --git a/backend/app/utils/path_helper.py b/backend/app/utils/path_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf7d5d72a2f33d29238eae3450ff4f598d647d4 --- /dev/null +++ b/backend/app/utils/path_helper.py @@ -0,0 +1,84 @@ +import os +import sys +from pathlib import Path + +PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) + + +def get_data_dir(): + if getattr(sys, 'frozen', False): + + base_dir = os.path.dirname(sys.executable) + else: + + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data")) + + data_path = os.path.join(base_dir, "data") + os.makedirs(data_path, exist_ok=True) + return data_path + + +def get_model_dir(subdir: str = "whisper") -> str: + # 判断是否为打包状态(PyInstaller) + if getattr(sys, 'frozen', False): + # exe 执行,放在 APPDATA 或 ~/.cache 下 + base_dir = os.path.join(os.getenv("APPDATA") or str(Path.home()), "VideoMemo", "models") + else: + # 开发时,相对项目根目录 + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../models")) + + path = os.path.join(base_dir, subdir) + os.makedirs(path, exist_ok=True) + return path + + +def get_app_dir(subdir: str = "") -> str: + """ + 返回一个稳定的可写目录: + - 开发时:使用项目 data 目录 + - 打包后:使用 exe 所在目录 + """ + if getattr(sys, 'frozen', False): + # 打包后运行:使用 main.exe 所在目录 + base_dir = os.path.dirname(sys.executable) + else: + # 开发模式:使用项目的 /data 目录 + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../data")) + + full_path = os.path.join(base_dir, subdir) + os.makedirs(full_path, exist_ok=True) + return full_path + + +def get_plugin_packages_dir() -> str: + """用户自装 Python 包的「插件目录」(桌面端可选依赖如 mlx_whisper 装这里)。 + + 打包后的应用是冻结的 Python 3.11,读不到系统 site-packages。 + main.py 启动时会把该目录加进 sys.path(PyInstaller 内置包优先, + 插件目录只补缺失的包),用户用 Python 3.11 安装后重启应用即可生效: + + python3.11 -m pip install --target "<本目录>" mlx_whisper + + 与 get_model_dir 同一基准目录(macOS: ~/VideoMemo,Windows: %APPDATA%/VideoMemo)。 + """ + base_dir = os.path.join(os.getenv("APPDATA") or str(Path.home()), "VideoMemo") + path = os.path.join(base_dir, "python-packages") + os.makedirs(path, exist_ok=True) + return path + + +def get_runtime_dir(name: str) -> str: + """对外提供静态资源(static/uploads)的根目录,保证「写入」与「服务」同源。 + + - 打包(PyInstaller frozen):使用 exe 所在目录的绝对路径。Tauri 以 sidecar 方式拉起 + 后端,其工作目录(cwd)不是项目目录(macOS 上常是只读位置),绝不能用 cwd 相对路径, + 否则截图/封面的写入目录与 /static 挂载目录不一致,桌面端图片全部 404。 + - 开发 / Docker:保持原有的 cwd 相对目录("./static" 等),不改变既有部署与卷挂载行为。 + """ + if getattr(sys, 'frozen', False): + base_dir = os.path.dirname(sys.executable) + else: + base_dir = "." + full_path = os.path.join(base_dir, name) + os.makedirs(full_path, exist_ok=True) + return full_path \ No newline at end of file diff --git a/backend/app/utils/response.py b/backend/app/utils/response.py new file mode 100644 index 0000000000000000000000000000000000000000..c021ef8e0a07290ffe269d0904964bb4f226c0f3 --- /dev/null +++ b/backend/app/utils/response.py @@ -0,0 +1,24 @@ +from fastapi.responses import JSONResponse +from app.utils.status_code import StatusCode +from pydantic import BaseModel +from typing import Optional, Any + + +from fastapi.responses import JSONResponse + +class ResponseWrapper: + @staticmethod + def success(data=None, msg="success", code=0): + return JSONResponse(content={ + "code": code, + "msg": msg, + "data": data + }) + + @staticmethod + def error(msg="error", code=500, data=None): + return JSONResponse(content={ + "code": code, + "msg": str(msg), + "data": data + }) \ No newline at end of file diff --git a/backend/app/utils/screenshot_marker.py b/backend/app/utils/screenshot_marker.py new file mode 100644 index 0000000000000000000000000000000000000000..0ee29405c30f2d4504fea71c30971656dc0bfa1b --- /dev/null +++ b/backend/app/utils/screenshot_marker.py @@ -0,0 +1,59 @@ +import re +from typing import List, Tuple + + +def _format_seconds(total_seconds: int) -> str: + total_seconds = max(0, int(total_seconds)) + return f"{total_seconds // 60:02d}:{total_seconds % 60:02d}" + + +def extract_screenshot_timestamps(markdown: str) -> List[Tuple[str, int]]: + pattern = r"(\*?Screenshot-(?:\[(\d{2}):(\d{2})\]|(\d{2}):(\d{2})))" + results: List[Tuple[str, int]] = [] + for match in re.finditer(pattern, markdown): + mm = match.group(2) or match.group(4) + ss = match.group(3) or match.group(5) + total_seconds = int(mm) * 60 + int(ss) + results.append((match.group(1), total_seconds)) + return results + + +def extract_content_timestamps(markdown: str, limit: int = 4) -> List[int]: + pattern = r"\*?Content-\[(\d{2}):(\d{2})\]" + seen: set[int] = set() + results: List[int] = [] + for match in re.finditer(pattern, markdown): + total_seconds = int(match.group(1)) * 60 + int(match.group(2)) + if total_seconds in seen: + continue + seen.add(total_seconds) + results.append(total_seconds) + if len(results) >= limit: + break + return results + + +def ensure_screenshot_markers(markdown: str, duration: float | int | None, max_markers: int = 3) -> str: + if extract_screenshot_timestamps(markdown) or extract_content_timestamps(markdown, limit=1): + return markdown + + try: + duration_seconds = max(0, int(float(duration or 0))) + except (TypeError, ValueError): + duration_seconds = 0 + + if duration_seconds <= 0: + timestamps = [0] + else: + timestamps = [ + max(0, min(duration_seconds - 1, round(duration_seconds * ratio))) + for ratio in (0.25, 0.5, 0.75) + ][:max_markers] + + unique_timestamps: List[int] = [] + for ts in timestamps: + if ts not in unique_timestamps: + unique_timestamps.append(ts) + + marker_block = "\n\n".join(f"*Screenshot-[{_format_seconds(ts)}]" for ts in unique_timestamps) + return f"{markdown.rstrip()}\n\n## 关键画面\n\n{marker_block}\n" diff --git a/backend/app/utils/status_code.py b/backend/app/utils/status_code.py new file mode 100644 index 0000000000000000000000000000000000000000..6e1c9ecd3506fe83effdb5595f7ecc75e1b933f6 --- /dev/null +++ b/backend/app/utils/status_code.py @@ -0,0 +1,12 @@ +from enum import IntEnum + +class StatusCode(IntEnum): + SUCCESS = 0 + FAIL = 1 + + DOWNLOAD_ERROR = 1001 + TRANSCRIBE_ERROR = 1002 + GENERATE_ERROR = 1003 + + INVALID_URL = 2001 + PARAM_ERROR = 2002 \ No newline at end of file diff --git a/backend/app/utils/url_parser.py b/backend/app/utils/url_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..5276fa44b80235752f50607404bb51cfcb4049d4 --- /dev/null +++ b/backend/app/utils/url_parser.py @@ -0,0 +1,77 @@ +import re +from typing import Optional +import requests + + +# 匹配文本中的第一个 http(s) 链接(贪到首个空白/中文/引号前为止) +_URL_RE = re.compile(r"https?://[^\s一-鿿\"'))】>,。、]+") + + +def clean_url(text: str) -> str: + """从「分享文案」里提取干净的链接。 + + 小红书/抖音/B 站的分享内容常是「标题 + 一堆不可见字符 + 链接」整段, + 直接丢给 yt-dlp 会被当成非法 URL(generic extractor 报 is not a valid URL)。 + 这里:去掉 BOM/零宽等不可见字符,再抓出第一个 http(s) 链接; + 没抓到链接就返回去空白后的原文(兼容本地路径等非 URL 输入)。 + """ + if not text: + return text + # 去掉 BOM/零宽空格/零宽连接符等不可见字符 + cleaned = re.sub(r"[​‌‍⁠]", "", text) + m = _URL_RE.search(cleaned) + if m: + return m.group(0).strip().rstrip(".,;") + return cleaned.strip() + + +def extract_video_id(url: str, platform: str) -> Optional[str]: + """ + 从视频链接中提取视频 ID + + :param url: 视频链接 + :param platform: 平台名(bilibili / youtube / douyin) + :return: 提取到的视频 ID 或 None + """ + if platform == "bilibili": + # 如果是短链接,则解析真实链接 + if "b23.tv" in url: + resolved_url = resolve_bilibili_short_url(url) + if resolved_url: + url = resolved_url + + # 匹配 BV号(如 BV1vc411b7Wa) + match = re.search(r"BV([0-9A-Za-z]+)", url) + return f"BV{match.group(1)}" if match else None + + elif platform == "youtube": + # 匹配 v=xxxxx 或 youtu.be/xxxxx,ID 长度通常为 11 + match = re.search(r"(?:v=|youtu\.be/)([0-9A-Za-z_-]{11})", url) + return match.group(1) if match else None + + elif platform == "douyin": + # 匹配 douyin.com/video/1234567890123456789 + match = re.search(r"/video/(\d+)", url) + return match.group(1) if match else None + + elif platform == "xiaohongshu": + # 匹配 explore/{id} 或 discovery/item/{id},id 通常是 24 位 hex + match = re.search(r"/(?:explore|discovery/item)/([0-9a-fA-F]+)", url) + return match.group(1) if match else None + + return None + + +def resolve_bilibili_short_url(short_url: str) -> Optional[str]: + """ + 解析哔哩哔哩短链接以获取真实视频链接 + + :param short_url: Bilibili短链接(如"https://b23.tv/xxxxxx") + :return: 真实的视频链接或None + """ + try: + response = requests.head(short_url, allow_redirects=True) + return response.url + except requests.RequestException as e: + print(f"Error resolving short URL: {e}") + return None diff --git a/backend/app/utils/video_helper.py b/backend/app/utils/video_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e601d812f1513c8e1adc32188f1da07a6b0606d1 --- /dev/null +++ b/backend/app/utils/video_helper.py @@ -0,0 +1,66 @@ +import shutil +from pathlib import Path + +from dotenv import load_dotenv +import subprocess +import os +import uuid + +from app.utils.path_helper import get_runtime_dir +load_dotenv() +api_path = os.getenv("API_BASE_URL", "http://localhost") +BACKEND_PORT= os.getenv("BACKEND_PORT", 8483) + +BACKEND_BASE_URL = f"{api_path}:{BACKEND_PORT}" + +from typing import Optional +def generate_screenshot(video_path: str, output_dir: str, timestamp: int, index: int) -> str: + """ + 使用 ffmpeg 生成截图,返回生成图片路径 + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + filename = f"screenshot_{index:03}_{uuid.uuid4()}.jpg" + output_path = output_dir / filename + + command = [ + "ffmpeg", + "-ss", str(timestamp), + "-i", str(video_path), + "-frames:v", "1", + "-q:v", "2", + str(output_path), + "-y" + ] + + print("Running command:", command) + result = subprocess.run(command, capture_output=True, text=True) + + if result.returncode != 0: + print("ffmpeg failed:", result.stderr) + + return str(output_path) + + + +def save_cover_to_static(local_cover_path: str, subfolder: Optional[str] = "cover") -> str: + """ + 将封面图片保存到 static 目录下,并返回前端可访问的路径 + :param local_cover_path: 本地原封面路径(比如提取出来的jpg) + :param subfolder: 子目录,默认是 cover,可以自定义 + :return: 前端访问路径,例如 /static/cover/xxx.jpg + """ + # 目标子目录:用 get_runtime_dir 取与 /static 挂载同源的 static 目录(开发/Docker=./static, + # 打包=exe 同级),不能用 os.getcwd()——Tauri sidecar 的 cwd 不是项目目录,否则封面写入位置 + # 与 /static 服务目录不一致,桌面端封面 404。 + target_dir = get_runtime_dir(os.path.join("static", subfolder or "cover")) + + # 拷贝文件 + file_name = os.path.basename(local_cover_path) + target_path = os.path.join(target_dir, file_name) + shutil.copy2(local_cover_path, target_path) # 保留原时间戳、权限 + image_relative_path = f"/static/{subfolder}/{file_name}".replace("\\", "/") + url_path = f"{BACKEND_BASE_URL.rstrip('/')}/{image_relative_path.lstrip('/')}" + # 返回前端可访问的路径 + return url_path diff --git a/backend/app/utils/video_reader.py b/backend/app/utils/video_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..100dbff61e0405bc77d55c6b97b1487a0975effb --- /dev/null +++ b/backend/app/utils/video_reader.py @@ -0,0 +1,183 @@ +import base64 +import hashlib +import os +import re +import subprocess +from concurrent.futures import ThreadPoolExecutor, as_completed +import ffmpeg +from PIL import Image, ImageDraw, ImageFont + +from app.utils.logger import get_logger +from app.utils.path_helper import get_app_dir + +logger = get_logger(__name__) +class VideoReader: + def __init__(self, + video_path: str, + grid_size=(3, 3), + frame_interval=2, + dedupe_enabled=True, + unit_width=960, + unit_height=540, + save_quality=90, + font_path="fonts/arial.ttf", + frame_dir=None, + grid_dir=None): + self.video_path = video_path + self.grid_size = grid_size + self.frame_interval = frame_interval + self.dedupe_enabled = dedupe_enabled + self.unit_width = unit_width + self.unit_height = unit_height + self.save_quality = save_quality + self.frame_dir = frame_dir or get_app_dir("output_frames") + self.grid_dir = grid_dir or get_app_dir("grid_output") + print(f"视频路径:{video_path}",self.frame_dir,self.grid_dir) + self.font_path = font_path + + @staticmethod + def _calculate_file_md5(file_path: str) -> str: + hasher = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + hasher.update(chunk) + return hasher.hexdigest() + + def format_time(self, seconds: float) -> str: + mm = int(seconds // 60) + ss = int(seconds % 60) + return f"{mm:02d}_{ss:02d}" + + def extract_time_from_filename(self, filename: str) -> float: + match = re.search(r"frame_(\d{2})_(\d{2})\.jpg", filename) + if match: + mm, ss = map(int, match.groups()) + return mm * 60 + ss + return float('inf') + + def _extract_single_frame(self, ts: int) -> str | None: + """提取单帧,返回输出路径或 None(失败时)。""" + time_label = self.format_time(ts) + output_path = os.path.join(self.frame_dir, f"frame_{time_label}.jpg") + cmd = ["ffmpeg", "-ss", str(ts), "-i", self.video_path, "-frames:v", "1", "-q:v", "2", "-y", output_path, + "-hide_banner", "-loglevel", "error"] + try: + subprocess.run(cmd, check=True) + return output_path + except subprocess.CalledProcessError: + return None + + def extract_frames(self, max_frames=1000) -> list[str]: + + try: + os.makedirs(self.frame_dir, exist_ok=True) + duration = float(ffmpeg.probe(self.video_path)["format"]["duration"]) + timestamps = [i for i in range(0, int(duration), self.frame_interval)][:max_frames] + + # 并行提取帧 + max_workers = min(os.cpu_count() or 4, 8, len(timestamps)) + frame_results: dict[int, str | None] = {} + with ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(self._extract_single_frame, ts): ts for ts in timestamps} + for future in as_completed(futures): + ts = futures[future] + frame_results[ts] = future.result() + + # 按时间戳顺序整理结果,并进行去重 + image_paths = [] + last_hash = None + for ts in timestamps: + output_path = frame_results.get(ts) + if not output_path or not os.path.exists(output_path): + continue + + if self.dedupe_enabled: + frame_hash = self._calculate_file_md5(output_path) + if frame_hash == last_hash: + os.remove(output_path) + continue + last_hash = frame_hash + + image_paths.append(output_path) + return image_paths + except Exception as e: + logger.error(f"分割帧发生错误:{str(e)}") + raise ValueError("视频处理失败") + + def group_images(self) -> list[list[str]]: + image_files = [os.path.join(self.frame_dir, f) for f in os.listdir(self.frame_dir) if + f.startswith("frame_") and f.endswith(".jpg")] + image_files.sort(key=lambda f: self.extract_time_from_filename(os.path.basename(f))) + group_size = self.grid_size[0] * self.grid_size[1] + return [image_files[i:i + group_size] for i in range(0, len(image_files), group_size)] + + def concat_images(self, image_paths: list[str], name: str) -> str: + os.makedirs(self.grid_dir, exist_ok=True) + font = ImageFont.truetype(self.font_path, 48) if os.path.exists(self.font_path) else ImageFont.load_default() + images = [] + + for path in image_paths: + img = Image.open(path).convert("RGB").resize((self.unit_width, self.unit_height), Image.Resampling.LANCZOS) + timestamp = re.search(r"frame_(\d{2})_(\d{2})\.jpg", os.path.basename(path)) + time_text = f"{timestamp.group(1)}:{timestamp.group(2)}" if timestamp else "" + draw = ImageDraw.Draw(img) + draw.text((10, 10), time_text, fill="yellow", font=font, stroke_width=1, stroke_fill="black") + images.append(img) + + cols, rows = self.grid_size + grid_img = Image.new("RGB", (self.unit_width * cols, self.unit_height * rows), (255, 255, 255)) + + for i, img in enumerate(images): + x = (i % cols) * self.unit_width + y = (i // cols) * self.unit_height + grid_img.paste(img, (x, y)) + + save_path = os.path.join(self.grid_dir, f"{name}.jpg") + grid_img.save(save_path, quality=self.save_quality) + return save_path + + def encode_images_to_base64(self, image_paths: list[str]) -> list[str]: + base64_images = [] + for path in image_paths: + with open(path, "rb") as img_file: + encoded_string = base64.b64encode(img_file.read()).decode("utf-8") + base64_images.append(f"data:image/jpeg;base64,{encoded_string}") + return base64_images + + def run(self)->list[str]: + logger.info("开始提取视频帧...") + try: + # 确保目录存在 + print(self.frame_dir,self.grid_dir) + os.makedirs(self.frame_dir, exist_ok=True) + os.makedirs(self.grid_dir, exist_ok=True) + #清空帧文件夹 + for file in os.listdir(self.frame_dir): + if file.startswith("frame_"): + os.remove(os.path.join(self.frame_dir, file)) + print(self.frame_dir,self.grid_dir) + #清空网格文件夹 + for file in os.listdir(self.grid_dir): + if file.startswith("grid_"): + os.remove(os.path.join(self.grid_dir, file)) + print(self.frame_dir,self.grid_dir) + self.extract_frames() + print("2#3",self.frame_dir,self.grid_dir) + logger.info("开始拼接网格图...") + image_paths = [] + groups = self.group_images() + for idx, group in enumerate(groups, start=1): + if len(group) < self.grid_size[0] * self.grid_size[1]: + logger.warning(f"⚠️ 跳过第 {idx} 组,图片不足 {self.grid_size[0] * self.grid_size[1]} 张") + continue + out_path = self.concat_images(group, f"grid_{idx}") + image_paths.append(out_path) + + logger.info("📤 开始编码图像...") + urls = self.encode_images_to_base64(image_paths) + return urls + except Exception as e: + logger.error(f"发生错误:{str(e)}") + raise ValueError("视频处理失败") + + diff --git a/backend/app/validators/__init__.py b/backend/app/validators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/app/validators/video_url_validator.py b/backend/app/validators/video_url_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa6c9769f9c6bb677f2ab8f758a293993450f87 --- /dev/null +++ b/backend/app/validators/video_url_validator.py @@ -0,0 +1,51 @@ +from pydantic import AnyUrl, validator, BaseModel, field_validator +import re +from urllib.parse import urlparse + +SUPPORTED_PLATFORMS = { + "bilibili": r"(https?://)?(www\.)?bilibili\.com/video/[a-zA-Z0-9]+", + "youtube": r"(https?://)?(www\.)?(youtube\.com/watch\?v=|youtu\.be/)[\w\-]+", + "douyin": "douyin", + "kuaishou": "kuaishou", + "xiaohongshu": "xiaohongshu", # 子串匹配,小红书 explore/discovery 都能命中 +} + + +def is_supported_video_url(url: str) -> bool: + parsed = urlparse(url) + + # 检查是否为Bilibili的短链接 + if parsed.netloc == "b23.tv": + return True + # 小红书短链 + if parsed.netloc.endswith("xhslink.com"): + return True + + for name, pattern in SUPPORTED_PLATFORMS.items(): + if pattern in ["douyin", "kuaishou", "xiaohongshu"]: + if pattern in url: + return True + else: + if re.match(pattern, url): + return True + + # 兜底:检查用户自定义的平台 + try: + from app.services.custom_platform_manager import match_custom_platform + if match_custom_platform(url): + return True + except Exception: + pass + + return False + + +class VideoRequest(BaseModel): + url: AnyUrl + platform: str + + @field_validator("url") + def validate_video_url(cls, v): + if not is_supported_video_url(str(v)): + raise ValueError("暂不支持该视频平台或链接格式无效") + return v diff --git a/backend/build.bat b/backend/build.bat new file mode 100644 index 0000000000000000000000000000000000000000..ebd12382138e70e53bba45c34d235a80bf08551e --- /dev/null +++ b/backend/build.bat @@ -0,0 +1,70 @@ +@echo off +setlocal enabledelayedexpansion + +REM 切换到脚本所在目录的上级,也就是项目根目录 +cd /d %~dp0.. +echo 当前工作目录:%cd% + +REM 清理旧的构建 +echo 清理旧的构建... +if exist backend\dist rmdir /s /q backend\dist +if exist backend\build rmdir /s /q backend\build +if exist VideoMemo_frontend\src-tauri\bin rmdir /s /q VideoMemo_frontend\src-tauri\bin +echo 清理完成。 + +REM 重新创建 Tauri 需要的目录结构 +mkdir VideoMemo_frontend\src-tauri\bin + +REM 获取 Rust 的 target triple(适配 Tauri 对应平台) +for /f "tokens=2 delims=:" %%A in ('rustc -Vv ^| findstr "host"') do ( + set "TARGET_TRIPLE=%%A" +) +set "TARGET_TRIPLE=%TARGET_TRIPLE: =%" +echo Detected target triple: %TARGET_TRIPLE% + + +REM --- 核心修改部分开始 --- + +REM 步骤 1: 为了避免 PyInstaller 的解析歧义,我们先手动复制文件 +echo 为打包准备 .env 文件... +copy backend\.env.example backend\.env + +REM 步骤 2: 执行 PyInstaller 打包,直接添加已存在的 .env 文件 +echo 开始 PyInstaller 打包... +pyinstaller ^ + -y ^ + --name VideoMemoBackend ^ + --paths backend ^ + --distpath VideoMemo_frontend\src-tauri\bin ^ + --workpath backend\build ^ + --specpath backend ^ + --hidden-import uvicorn ^ + --hidden-import fastapi ^ + --hidden-import starlette ^ + --hidden-import chromadb.api.rust ^ + --collect-all chromadb ^ + --collect-all chromadb_rust_bindings ^ + --exclude-module torch ^ + --exclude-module torchvision ^ + --exclude-module torchaudio ^ + --exclude-module mlx_whisper.torch_whisper ^ + --add-data "app\db\builtin_providers.json;." ^ + --add-data ".env;." ^ + backend\main.py + +REM 步骤 3: 清理在项目根目录创建的临时 .env 文件 +echo 清理临时的 .env 文件... +del backend\.env + +REM --- 核心修改部分结束 --- + + +REM 重命名生成的可执行文件为符合 Tauri 要求的名称 +move /Y VideoMemo_frontend\src-tauri\bin\VideoMemoBackend\VideoMemoBackend.exe VideoMemo_frontend\src-tauri\bin\VideoMemoBackend\VideoMemoBackend-%TARGET_TRIPLE%.exe + +echo PyInstaller 打包完成: +dir VideoMemo_frontend\src-tauri\bin\VideoMemoBackend + +echo 请检查 VideoMemo_frontend\src-tauri\bin\VideoMemoBackend 目录,确认其中包含了名为 .env 的【文件】。 + +endlocal \ No newline at end of file diff --git a/backend/build.sh b/backend/build.sh new file mode 100755 index 0000000000000000000000000000000000000000..a83500bab2dddbb0b31255dfd521f61f1dc961c1 --- /dev/null +++ b/backend/build.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +set -e +# uncomment this for debugging +# set -x + +# 切到项目根(假设脚本放在 script/ 目录) +cd "$(dirname "$0")/.." + +echo "当前工作目录:$(pwd)" + +# 清理旧的构建 +echo "清理旧的构建..." +rm -rf backend/dist backend/build ./VideoMemo_frontend/src-tauri/bin/* +echo "清理完成。" + +TARGET_TRIPLE=$(rustc -Vv | grep host | cut -f2 -d' ') +echo "Detected target triple: $TARGET_TRIPLE" + +# --- 核心修改部分开始 --- + +# 步骤 1: 为了避免 PyInstaller 的解析歧义,我们先手动复制文件 +echo "为打包准备 .env 文件..." +cp backend/.env.example backend/.env + +# 步骤 2: PyInstaller 打包,直接添加已存在的 .env 文件 +echo "开始 PyInstaller 打包..." +python -m PyInstaller \ + -y \ + --name VideoMemoBackend \ + --paths backend \ + --distpath ./VideoMemo_frontend/src-tauri/bin \ + --workpath backend/build \ + --specpath backend \ + --hidden-import uvicorn \ + --hidden-import fastapi \ + --hidden-import starlette \ + --hidden-import chromadb.api.rust \ + --collect-all chromadb \ + --collect-all chromadb_rust_bindings \ + --exclude-module torch \ + --exclude-module torchvision \ + --exclude-module torchaudio \ + --exclude-module funasr \ + --exclude-module mlx_whisper.torch_whisper \ + --add-data "app/db/builtin_providers.json:." \ + --add-data ".env:." \ + "$(pwd)/backend/main.py" + +# 步骤 3: 清理在项目根目录创建的临时 .env 文件 +echo "清理临时的 .env 文件..." +rm backend/.env + +# --- 核心修改部分结束 --- + + +# 重命名主执行文件以包含目标平台信息 +mv \ + ./VideoMemo_frontend/src-tauri/bin/VideoMemoBackend/VideoMemoBackend\ + ./VideoMemo_frontend/src-tauri/bin/VideoMemoBackend/VideoMemoBackend-$TARGET_TRIPLE + +echo "PyInstaller 打包完成。" +echo "打包后的目录内容:" +ls -l ./VideoMemo_frontend/src-tauri/bin/VideoMemoBackend + +echo "请检查 src-tauri/bin/VideoMemoBackend 目录,确认其中包含了名为 .env 的【文件】。" + diff --git a/backend/events/__init__.py b/backend/events/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4e810d425bc2963455ab43bab6584ab16e2cd5 --- /dev/null +++ b/backend/events/__init__.py @@ -0,0 +1,14 @@ +# 注册监听器 +from app.utils.logger import get_logger +from events.handlers import cleanup_temp_files +from events.signals import transcription_finished + +logger = get_logger(__name__) + +def register_handler(): + try: + transcription_finished.connect(cleanup_temp_files) + logger.info("注册监听器成功") + except Exception as e: + logger.error(f"注册监听器失败:{e}") + diff --git a/backend/events/handlers.py b/backend/events/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b00154436d7599f3b81a425fc704f0fc554fc0 --- /dev/null +++ b/backend/events/handlers.py @@ -0,0 +1,25 @@ +import os +from app.utils.logger import get_logger +logger = get_logger(__name__) + +def cleanup_temp_files(data): + logger.info(f"starting cleanup temp files :{data['file_path']}") + file_path = data['file_path'] + if not os.path.exists(file_path): + logger.warning(f"路径不存在:{file_path}") + return + + dir_path = os.path.dirname(file_path) + base_name = os.path.basename(file_path) + video_id, _ = os.path.splitext(base_name) + + logger.info(f"开始清理 video_id={video_id} 所有相关文件") + + for file in os.listdir(dir_path): + if file.startswith(video_id): + full_path = os.path.join(dir_path, file) + try: + os.remove(full_path) + logger.info(f"删除文件:{full_path}") + except Exception as e: + logger.error(f"删除失败:{full_path},原因:{e}") diff --git a/backend/events/signals.py b/backend/events/signals.py new file mode 100644 index 0000000000000000000000000000000000000000..18e56321e1b207bab16fce722c779661ad9802a2 --- /dev/null +++ b/backend/events/signals.py @@ -0,0 +1,2 @@ +from blinker import signal +transcription_finished = signal("transcription_finished") \ No newline at end of file diff --git a/backend/ffmpeg_helper.py b/backend/ffmpeg_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1254ee683f4b8193ca59c34b159f3cbbf469f167 --- /dev/null +++ b/backend/ffmpeg_helper.py @@ -0,0 +1,79 @@ +import os +import subprocess +import sys +from dotenv import load_dotenv + +from app.utils.logger import get_logger +logger = get_logger(__name__) + + +def _load_dotenv_from_multiple_paths(): + """尝试多个位置加载 .env,适配源码运行和 PyInstaller 打包场景。 + + PyInstaller 打包后当前工作目录是 EXE 所在目录,而源码运行时 .env + 通常在项目根目录或 backend/ 同级。遍历常见候选路径确保能命中。 + """ + candidates = [] + # 1. 当前工作目录(EXE 所在目录) + candidates.append(os.path.join(os.getcwd(), '.env')) + # 2. 本脚本所在目录(backend/) + script_dir = os.path.dirname(os.path.abspath(__file__)) + candidates.append(os.path.join(script_dir, '.env')) + # 3. 项目根目录(backend/../.env) + candidates.append(os.path.join(script_dir, '..', '.env')) + # 4. PyInstaller 打包后的 _internal/ 子目录 + if getattr(sys, 'frozen', False): + exe_dir = os.path.dirname(sys.executable) + candidates.append(os.path.join(exe_dir, '_internal', '.env')) + + for path in candidates: + normalized = os.path.normpath(path) + if os.path.isfile(normalized): + load_dotenv(normalized) + return + # 都没找到,fallback 到默认行为(从 CWD 找) + load_dotenv() + + +_load_dotenv_from_multiple_paths() +def check_ffmpeg_exists() -> bool: + """ + 检查 ffmpeg 是否可用。优先使用 FFMPEG_BIN_PATH 环境变量指定的路径。 + """ + ffmpeg_bin_path = os.getenv("FFMPEG_BIN_PATH") + logger.info(f"FFMPEG_BIN_PATH: {ffmpeg_bin_path}") + if ffmpeg_bin_path and os.path.isdir(ffmpeg_bin_path): + os.environ["PATH"] = ffmpeg_bin_path + os.pathsep + os.environ.get("PATH", "") + logger.info(f"使用FFMPEG_BIN_PATH: {ffmpeg_bin_path}") + else: + # 遍历系统PATH寻找ffmpeg.exe + system_path = os.environ.get("PATH", "") + path_dirs = system_path.split(os.pathsep) + for path_dir in path_dirs: + ffmpeg_exe_path = os.path.join(path_dir, "ffmpeg.exe") + if os.path.isfile(ffmpeg_exe_path): + os.environ["PATH"] = path_dir + os.pathsep + system_path + logger.info(f"在系统PATH中找到ffmpeg: {path_dir}") + break + try: + subprocess.run(["ffmpeg", "-version"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True) + logger.info("ffmpeg 已安装") + return True + except (FileNotFoundError, OSError, subprocess.CalledProcessError): + logger.info("ffmpeg 未安装") + return False + + +def ensure_ffmpeg_or_raise(): + """ + 校验 ffmpeg 是否可用,否则抛出异常并提示安装方式。 + """ + if not check_ffmpeg_exists(): + logger.error("未检测到 ffmpeg,请先安装后再使用本功能。") + raise EnvironmentError( + " 未检测到 ffmpeg,请先安装后再使用本功能。\n" + "👉 下载地址:https://ffmpeg.org/download.html\n" + "🪟 Windows 推荐:https://www.gyan.dev/ffmpeg/builds/\n" + "💡 如果你已安装,请将其路径写入 `.env` 文件,例如:\n" + "FFMPEG_BIN_PATH=/your/custom/ffmpeg/bin" + ) diff --git a/backend/main.py b/backend/main.py new file mode 100644 index 0000000000000000000000000000000000000000..82f1b6070500af5727b645f6bc37b51e7274a88f --- /dev/null +++ b/backend/main.py @@ -0,0 +1,129 @@ +import os +import sys +from contextlib import asynccontextmanager +from pathlib import Path + +# ─── 插件目录(必须在 import app.* 之前执行)──────────────────────────── +# 桌面端(PyInstaller 冻结)读不到系统 site-packages。把用户插件目录加进 +# sys.path,让用户可以自装可选依赖(如 mlx_whisper): +# python3.11 -m pip install --target "<插件目录>" mlx_whisper +# PyInstaller 的 FrozenImporter 优先于 sys.path,内置包不会被插件目录覆盖, +# 插件目录只补「包里没有」的模块。安装后重启应用生效。 +if getattr(sys, "frozen", False): + _plugin_dir = os.path.join( + os.getenv("APPDATA") or str(Path.home()), "VideoMemo", "python-packages" + ) + os.makedirs(_plugin_dir, exist_ok=True) + if _plugin_dir not in sys.path: + sys.path.insert(0, _plugin_dir) + +import uvicorn +from fastapi import FastAPI +from starlette.middleware.cors import CORSMiddleware +from starlette.middleware.gzip import GZipMiddleware +from starlette.staticfiles import StaticFiles +from dotenv import load_dotenv + +from app.db.init_db import init_db +from app.db.provider_dao import seed_default_providers +from app.exceptions.exception_handlers import register_exception_handlers +# from app.db.model_dao import init_model_table +# from app.db.provider_dao import init_provider_table +from app.utils.logger import get_logger +from app.utils.path_helper import get_runtime_dir +from app import create_app +from app.services.transcriber_config_manager import TranscriberConfigManager +from app.services.scheduler import get_scheduler +from events import register_handler +from ffmpeg_helper import ensure_ffmpeg_or_raise + +logger = get_logger(__name__) +load_dotenv() + +# 读取 .env 中的路径 +static_path = os.getenv('STATIC', '/static') + +# 静态资源根目录用 get_runtime_dir:开发/Docker 维持 "./static",PyInstaller 打包后切到 +# exe 同级目录。挂载目录必须和 note.py / video_helper.py 的写入目录同源——Tauri sidecar 的 +# cwd 不是项目目录,之前用相对 "static" 会让两者错位,导致桌面端正文截图和侧边栏封面全 404。 +static_dir = get_runtime_dir("static") +uploads_dir = get_runtime_dir("uploads") + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 启动序列拆成 5 步、每步独立日志 + 异常时打明确的 [startup N/5 FAILED] 标记。 + # 目的:用户 docker logs 一眼能看出后端死在哪一步,避免「容器一直重启但看不出原因」。 + try: + logger.info("[startup 1/5] register_handler() — 注册事件处理器") + register_handler() + + logger.info("[startup 2/5] init_db() — 初始化 SQLite 数据库") + init_db() + + logger.info("[startup 3/5] TranscriberConfigManager — 读取转写器配置") + # 转写器不再在启动时强制初始化,而是在首次生成笔记时按需创建。 + # 如果配置了不可用的类型(如 mlx-whisper 未安装),会在使用时报错而非静默回退。 + _cfg = TranscriberConfigManager().get_config() + logger.info( + f" 当前转写器: type={_cfg['transcriber_type']}, " + f"model_size={_cfg['whisper_model_size']}" + ) + + logger.info("[startup 4/5] seed_default_providers() — 初始化默认 LLM 供应商") + seed_default_providers() + + logger.info("[startup 5/5] 启动完成,等待请求") + get_scheduler().start() + except Exception: + logger.exception("[startup FAILED] 后端启动期异常,详见堆栈;容器会退出并由 restart 策略决定是否重试") + raise + + yield + + get_scheduler().stop() + +app = create_app(lifespan=lifespan) + +# 允许的源:本地 web 端 + Tauri 桌面端 + 浏览器扩展(chrome/edge/firefox) +# 用 regex 是因为 chrome-extension:// 的 id 在每次开发版加载时不固定 +# Tauri 2 不同平台 webview origin 不一样,必须全列: +# - macOS: tauri://localhost (自定义协议) +# - Windows: https://tauri.localhost (Edge WebView2) +# - Linux: http://tauri.localhost (WebKitGTK) +# 漏掉哪个都会导致桌面端 fetch 返回 200 但 browser 因为 CORS 拒绝读响应, +# 表现为前端「连不上后端」但后端日志一片 200 OK。 +CORS_ORIGIN_REGEX = ( + r"^chrome-extension://[a-z]+$" + r"|^moz-extension://.+$" + r"|^http://(localhost|127\.0\.0\.1)(:\d+)?$" + r"|^tauri://localhost$" + r"|^https?://tauri\.localhost$" + # Cloudflare Pages:.pages.dev 及其预览子域 ..pages.dev + r"|^https://([a-z0-9-]+\.)*pages\.dev$" +) + +app.add_middleware( + CORSMiddleware, + allow_origin_regex=CORS_ORIGIN_REGEX, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.add_middleware(GZipMiddleware, minimum_size=1000) +register_exception_handlers(app) +app.mount(static_path, StaticFiles(directory=static_dir), name="static") +app.mount("/uploads", StaticFiles(directory=uploads_dir), name="uploads") + + + + + + + + + +if __name__ == "__main__": + port = int(os.getenv("BACKEND_PORT", 8483)) + host = os.getenv("BACKEND_HOST", "0.0.0.0") + logger.info(f"Starting server on {host}:{port}") + uvicorn.run(app, host=host, port=port, reload=False) \ No newline at end of file diff --git a/backend/requirements.txt b/backend/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..07ab8481aac58b4008a24eae65cc867a9327427e --- /dev/null +++ b/backend/requirements.txt @@ -0,0 +1,134 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.11.16 +aiosignal==1.3.2 +altgraph==0.17.4 +amqp==5.3.1 +annotated-types==0.7.0 +anyio==4.9.0 +async-timeout==5.0.1 +attrs==25.3.0 +av==14.2.0 +beautifulsoup4==4.13.4 +billiard==4.2.1 +blinker==1.9.0 +Brotli==1.1.0 +celery==5.5.1 +certifi==2025.1.31 +cffi==1.17.1 +charset-normalizer==3.4.1 +chromadb>=0.5.0 +click==8.1.8 +click-didyoumean==0.3.1 +click-plugins==1.1.1 +click-repl==0.3.0 +colorama==0.4.6 +coloredlogs==15.0.1 +cssselect2==0.8.0 +ctranslate2==4.6.0 +distro==1.9.0 +dnspython==2.7.0 +email_validator==2.2.0 +exceptiongroup==1.2.2 +fastapi==0.115.12 +fastapi-cli==0.0.7 +faster-whisper==1.1.1 +# 可选转写引擎 FunASR(阿里·中文,效果常优于 Whisper):默认不安装,依赖较重(含 torch,约 2GB)。 +# 需要时手动安装后即可在「音频转写配置」选用:pip install funasr torch torchaudio +ffmpeg-python==0.2.0 +filelock==3.18.0 +flatbuffers==25.2.10 +fonttools==4.58.4 +frozenlist==1.5.0 +fsspec==2025.3.2 +future==1.0.0 +gmssl==3.2.2 +h11==0.14.0 +hf-xet==1.0.0 +httpcore==1.0.7 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.30.2 +humanfriendly==10.0 +humanize==4.12.2 +idna==3.10 +Jinja2==3.1.6 +jiter==0.9.0 +kombu==5.5.2 +lxml==5.4.0 +macholib==1.16.3 +Markdown==3.8 +markdown-it-py==3.0.0 +markdown_pdf==1.7 +MarkupSafe==3.0.2 +mdurl==0.1.2 +modelscope==1.25.0 +mpmath==1.3.0 +multidict==6.4.3 +networkx==3.3 +numpy==2.2.4 +onnxruntime==1.21.0 +openai==1.70.0 +orjson==3.10.16 +packaging==24.2 +pdfkit==1.0.0 +pefile==2023.2.7 +pillow==11.0.0 +prometheus_client==0.21.1 +prompt_toolkit==3.0.50 +propcache==0.3.1 +protobuf==6.30.2 +# PostgreSQL 驱动:DATABASE_URL 指向 Postgres(如 Supabase)时由 SQLAlchemy 使用; +# 用 SQLite 时不加载,留着无副作用。 +psycopg2-binary==2.9.10 +pycparser==2.22 +pycryptodomex==3.22.0 +pydantic==2.11.2 +pydantic_core==2.33.1 +pydyf==0.11.0 +PyExecJS==1.5.1 +Pygments==2.19.1 +pyinstaller==6.13.0 +pyinstaller-hooks-contrib==2025.2 +PyMuPDF==1.25.3 +pyphen==0.17.2 +pyreadline3==3.5.4 +python-dateutil==2.9.0.post0 +python-docx==1.2.0 +python-dotenv==1.1.0 +python-multipart==0.0.20 +pytz==2025.2 +pywin32-ctypes==0.2.3 +PyYAML==6.0.2 +redis==5.2.1 +requests==2.32.3 +rich==14.0.0 +rich-toolkit==0.14.1 +shellingham==1.5.4 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.7 +starlette==0.46.1 +sympy==1.13.1 +SQLAlchemy==2.0.41 +tenacity==9.1.2 +tinycss2==1.4.0 +tinyhtml5==2.0.0 +tokenizers==0.21.1 +tornado==6.4.2 +tqdm==4.67.1 +typer==0.15.2 +typing-inspection==0.4.0 +tzdata==2025.2 +urllib3==2.3.0 +uvicorn==0.34.0 +uvloop==0.21.0; sys_platform != "win32" +vine==5.1.0 +watchfiles==1.0.4 +wcwidth==0.2.13 +weasyprint==65.1 +webencodings==0.5.1 +websockets==15.0.1 +yarl==1.19.0 +youtube-transcript-api>=1.0.0 +yt-dlp==2025.3.31 +zopfli==0.2.3.post1 diff --git a/backend/run.bat b/backend/run.bat new file mode 100644 index 0000000000000000000000000000000000000000..7a84f979de9940293ba467a8e1c109dc6c70485e --- /dev/null +++ b/backend/run.bat @@ -0,0 +1 @@ +python main.py \ No newline at end of file diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/backend/tests/test_article_dao.py b/backend/tests/test_article_dao.py new file mode 100644 index 0000000000000000000000000000000000000000..600dd17145e21e2d03ba5c8dd1bd794cc7b14f52 --- /dev/null +++ b/backend/tests/test_article_dao.py @@ -0,0 +1,62 @@ +import importlib + +from app.article_fetchers.base import ArticleContent + + +def _load_article_dao(tmp_path, monkeypatch): + monkeypatch.setenv("DATABASE_URL", f"sqlite:///{tmp_path / 'articles.db'}") + import app.db.engine as engine + import app.db.models.articles as article_models + import app.db.init_db as init_db + + importlib.reload(engine) + importlib.reload(article_models) + importlib.reload(init_db) + init_db.init_db() + + import app.db.article_dao as article_dao + + return importlib.reload(article_dao) + + +def test_upsert_article_item_dedupes_by_platform_and_article_id(tmp_path, monkeypatch): + article_dao = _load_article_dao(tmp_path, monkeypatch) + article = ArticleContent( + platform="wechat_mp", + url="https://mp.weixin.qq.com/s/a", + article_id="biz:mid:1:sn", + title="标题", + author_name="公众号", + content_text="正文", + ) + + first = article_dao.upsert_article_item(article) + second = article_dao.upsert_article_item(article) + + assert first.id == second.id + assert len(article_dao.list_article_items()) == 1 + + +def test_create_subscription_and_link_item(tmp_path, monkeypatch): + article_dao = _load_article_dao(tmp_path, monkeypatch) + article = article_dao.upsert_article_item( + ArticleContent( + platform="xiaohongshu", + url="https://www.xiaohongshu.com/explore/a", + article_id="a", + title="小红书标题", + content_text="正文", + ) + ) + + subscription = article_dao.create_subscription( + platform="xiaohongshu", + subscription_type="keyword", + query="AI", + label="AI", + ) + article_dao.link_subscription_item(subscription.id, article.id, "keyword:AI") + + assert article_dao.list_subscriptions()[0].query == "AI" + assert article_dao.get_article_item(article.id).title == "小红书标题" + assert article_dao.list_article_items(subscription_id=subscription.id)[0].id == article.id diff --git a/backend/tests/test_article_fetchers_wechat.py b/backend/tests/test_article_fetchers_wechat.py new file mode 100644 index 0000000000000000000000000000000000000000..0e32485bb9d2661f7ebec240739323664a5f5456 --- /dev/null +++ b/backend/tests/test_article_fetchers_wechat.py @@ -0,0 +1,75 @@ +from app.article_fetchers.wechat import parse_wechat_article_html +from app.article_fetchers.wechat import parse_wechat_search_html + + +WECHAT_HTML = """ + + ignored + +

一篇公众号文章标题

+ VideoMemo实验室 + 2026-06-08 +
+

第一段正文

+

第二段正文

+ +
+ + + +""" + + +def test_parse_wechat_article_extracts_core_fields(): + article = parse_wechat_article_html( + WECHAT_HTML, + "https://mp.weixin.qq.com/s/example", + ) + + assert article.platform == "wechat_mp" + assert article.article_id == "MzExample:123456:1:abcdef" + assert article.title == "一篇公众号文章标题" + assert article.author_name == "VideoMemo实验室" + assert article.published_at == "2026-06-08" + assert "第一段正文" in article.content_text + assert "第二段正文" in article.content_text + assert article.image_urls == ["https://mmbiz.qpic.cn/example.jpg"] + + +def test_parse_wechat_article_fails_when_body_is_empty(): + html = '

标题

' + + try: + parse_wechat_article_html(html, "https://mp.weixin.qq.com/s/empty") + except ValueError as exc: + assert "正文" in str(exc) + else: + raise AssertionError("expected parser to reject empty article body") + + +def test_parse_wechat_search_html_extracts_article_results(): + html = """ + + +
+ AI工具清单 +

公众号作者

+

正文摘要

+
+ + + """ + + items = parse_wechat_search_html(html, "AI工具", limit=5) + + assert len(items) == 1 + assert items[0].platform == "wechat_mp" + assert items[0].title == "AI工具清单" + assert items[0].url == "https://mp.weixin.qq.com/s/abc" + assert items[0].author_name == "公众号作者" + assert items[0].content_text == "正文摘要" diff --git a/backend/tests/test_article_fetchers_xiaohongshu.py b/backend/tests/test_article_fetchers_xiaohongshu.py new file mode 100644 index 0000000000000000000000000000000000000000..6f743e1b5a4f3dce17fe64c74c4b8c3ae7e5b512 --- /dev/null +++ b/backend/tests/test_article_fetchers_xiaohongshu.py @@ -0,0 +1,103 @@ +from app.article_fetchers.xiaohongshu import ( + parse_xiaohongshu_article_html, + parse_xiaohongshu_discovery_html, +) + + +XHS_HTML = """ + + + + + +""" + + +def test_parse_xiaohongshu_article_extracts_embedded_note(): + article = parse_xiaohongshu_article_html( + XHS_HTML, + "https://www.xiaohongshu.com/explore/abc123", + ) + + assert article.platform == "xiaohongshu" + assert article.article_id == "abc123" + assert article.title == "小红书图文标题" + assert article.author_name == "作者A" + assert article.author_id == "u1" + assert "第一段" in article.content_text + assert "第二段" in article.content_text + assert article.cover_url == "https://sns-img-qc.xhscdn.com/a.jpg" + assert article.image_urls == [ + "https://sns-img-qc.xhscdn.com/a.jpg", + "https://sns-img-qc.xhscdn.com/b.jpg", + ] + + +def test_parse_xiaohongshu_article_falls_back_to_meta_text(): + html = """ + + + """ + + article = parse_xiaohongshu_article_html( + html, + "https://www.xiaohongshu.com/explore/fallback", + ) + + assert article.article_id == "fallback" + assert article.title == "备用标题" + assert article.content_text == "备用正文" + + +def test_parse_xiaohongshu_discovery_html_extracts_note_cards(): + html = """ + + """ + + items = parse_xiaohongshu_discovery_html( + html, + source_url="https://www.xiaohongshu.com/search_result?keyword=AI", + limit=10, + ) + + assert len(items) == 1 + assert items[0].platform == "xiaohongshu" + assert items[0].article_id == "note-a" + assert items[0].title == "AI工具分享" + assert items[0].url == "https://www.xiaohongshu.com/explore/note-a" + assert items[0].author_name == "作者A" + assert items[0].cover_url == "https://sns-img-qc.xhscdn.com/cover.jpg" diff --git a/backend/tests/test_article_routes.py b/backend/tests/test_article_routes.py new file mode 100644 index 0000000000000000000000000000000000000000..bd08e0476b87f5440e20629ab02f673dda3968f9 --- /dev/null +++ b/backend/tests/test_article_routes.py @@ -0,0 +1,122 @@ +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app import create_app + + +@asynccontextmanager +async def noop_lifespan(app: FastAPI): + yield + + +class FakeService: + def generate_from_url(self, **kwargs): + return {"task_id": kwargs.get("task_id") or "task-1", "article_item_id": 1} + + def search(self, platform, keyword, limit=20): + return {"platform": platform, "keyword": keyword, "status": "ok", "message": "", "items": []} + + def create_subscription(self, platform, subscription_type, query, label=""): + return {"id": 1, "platform": platform, "type": subscription_type, "query": query, "label": label} + + def list_subscriptions(self): + return [] + + def refresh_subscription(self, subscription_id, limit=20): + return {"subscription_id": subscription_id, "count": 0, "items": []} + + def list_items(self, subscription_id=None): + return [] + + def summarize_item(self, item_id, **kwargs): + return {"task_id": "task-1", "article_item_id": item_id} + + def fetch_only_from_url(self, url, platform): + return {"id": 1, "platform": platform, "url": url, "title": "Test Title"} + + def import_only_content(self, url, platform, title, content_text, author_name=""): + return {"id": 1, "platform": platform, "url": url, "title": title, "content_text": content_text} + + + +def app_with_fake_service(monkeypatch): + from app.routers import article + + monkeypatch.setattr(article, "ArticleService", lambda: FakeService()) + return TestClient(create_app(lifespan=noop_lifespan)) + + +def test_generate_article_route(monkeypatch): + client = app_with_fake_service(monkeypatch) + + response = client.post( + "/api/articles/generate", + json={ + "url": "https://mp.weixin.qq.com/s/a", + "platform": "wechat_mp", + "provider_id": "p", + "model_name": "m", + "style": "minimal", + }, + ) + + assert response.status_code == 200 + assert response.json()["data"]["task_id"] == "task-1" + + +def test_article_search_route(monkeypatch): + client = app_with_fake_service(monkeypatch) + + response = client.get("/api/articles/search?platform=xiaohongshu&keyword=AI") + + assert response.status_code == 200 + assert response.json()["data"]["keyword"] == "AI" + + +def test_article_subscription_routes(monkeypatch): + client = app_with_fake_service(monkeypatch) + + created = client.post( + "/api/article_subscriptions", + json={"platform": "wechat_mp", "type": "publisher", "query": "账号", "label": "账号"}, + ) + refreshed = client.post("/api/article_subscriptions/1/refresh") + + assert created.status_code == 200 + assert created.json()["data"]["query"] == "账号" + assert refreshed.json()["data"]["subscription_id"] == 1 + + +def test_fetch_article_route(monkeypatch): + client = app_with_fake_service(monkeypatch) + + response = client.post( + "/api/articles/fetch", + json={ + "url": "https://mp.weixin.qq.com/s/a", + "platform": "wechat_mp", + }, + ) + + assert response.status_code == 200 + assert response.json()["data"]["url"] == "https://mp.weixin.qq.com/s/a" + + +def test_import_article_route(monkeypatch): + client = app_with_fake_service(monkeypatch) + + response = client.post( + "/api/articles/import", + json={ + "url": "https://mp.weixin.qq.com/s/a", + "platform": "wechat_mp", + "title": "My Title", + "content_text": "This is a long text content for test.", + }, + ) + + assert response.status_code == 200 + assert response.json()["data"]["title"] == "My Title" + diff --git a/backend/tests/test_article_service.py b/backend/tests/test_article_service.py new file mode 100644 index 0000000000000000000000000000000000000000..fe22cf838838a6eefae1dcdecc3f3300d3455c46 --- /dev/null +++ b/backend/tests/test_article_service.py @@ -0,0 +1,155 @@ +import importlib +import json + +from app.article_fetchers.base import ArticleContent + + +def _load_article_service(tmp_path, monkeypatch): + monkeypatch.setenv("DATABASE_URL", f"sqlite:///{tmp_path / 'articles.db'}") + monkeypatch.setenv("NOTE_OUTPUT_DIR", str(tmp_path / "notes")) + import app.db.engine as engine + import app.db.models.articles as article_models + import app.db.init_db as init_db + + importlib.reload(engine) + importlib.reload(article_models) + importlib.reload(init_db) + init_db.init_db() + + import app.db.article_dao as article_dao + import app.services.article as article_service + + importlib.reload(article_dao) + return importlib.reload(article_service) + + +class FakeGPT: + total_tokens = 42 + + def summarize(self, source): + text = "\n".join(segment.text for segment in source.segment) + assert "正文内容" in text + return "# 总结\n\n- 要点" + + +class FakeFetcher: + platform = "wechat_mp" + + def fetch(self, url): + return ArticleContent( + platform="wechat_mp", + url=url, + article_id="article-1", + title="文章标题", + author_name="作者", + content_text="正文内容", + cover_url="https://example.com/cover.jpg", + ) + + def search(self, keyword: str, limit: int = 20): + return [] + + def fetch_publisher(self, query: str, limit: int = 20): + return [] + + +class SearchFetcher(FakeFetcher): + platform = "xiaohongshu" + + def search(self, keyword: str, limit: int = 20): + return [ + ArticleContent( + platform="xiaohongshu", + url="https://www.xiaohongshu.com/explore/search-1", + article_id="search-1", + title=f"{keyword} 搜索结果", + author_name="作者", + content_text="正文", + ) + ] + + def fetch_publisher(self, query: str, limit: int = 20): + return [ + ArticleContent( + platform="xiaohongshu", + url="https://www.xiaohongshu.com/explore/pub-1", + article_id="pub-1", + title=f"{query} 发布者结果", + author_name=query, + author_id=query, + content_text="正文", + ) + ] + + +def test_generate_from_url_saves_note_json(tmp_path, monkeypatch): + article_service = _load_article_service(tmp_path, monkeypatch) + service = article_service.ArticleService( + fetchers={"wechat_mp": FakeFetcher()}, + gpt_factory=lambda *_: FakeGPT(), + ) + + result = service.generate_from_url( + url="https://mp.weixin.qq.com/s/a", + platform="wechat_mp", + provider_id="provider", + model_name="model", + style="minimal", + extras="", + task_id="task-1", + ) + + saved = json.loads((tmp_path / "notes" / "task-1.json").read_text(encoding="utf-8")) + status = json.loads((tmp_path / "notes" / "task-1.status.json").read_text(encoding="utf-8")) + assert result["task_id"] == "task-1" + assert saved["markdown"] == "# 总结\n\n- 要点" + assert saved["transcript"]["full_text"] == "正文内容" + assert saved["audio_meta"]["title"] == "文章标题" + assert saved["audio_meta"]["platform"] == "wechat_mp" + assert saved["audio_meta"]["video_id"] == "article-1" + assert saved["total_tokens"] == 42 + assert status["status"] == "SUCCESS" + + +def test_search_by_keyword_persists_results(tmp_path, monkeypatch): + article_service = _load_article_service(tmp_path, monkeypatch) + service = article_service.ArticleService( + fetchers={"xiaohongshu": SearchFetcher()}, + gpt_factory=lambda *_: FakeGPT(), + ) + + result = service.search(platform="xiaohongshu", keyword="AI", limit=10) + + assert result["status"] == "ok" + assert result["items"][0]["title"] == "AI 搜索结果" + assert len(article_service.list_article_items()) == 1 + + +def test_refresh_keyword_subscription_links_items(tmp_path, monkeypatch): + article_service = _load_article_service(tmp_path, monkeypatch) + service = article_service.ArticleService( + fetchers={"xiaohongshu": SearchFetcher()}, + gpt_factory=lambda *_: FakeGPT(), + ) + subscription = article_service.create_subscription("xiaohongshu", "keyword", "AI", "AI") + + result = service.refresh_subscription(subscription.id) + + assert result["subscription_id"] == subscription.id + assert result["count"] == 1 + assert article_service.list_subscriptions()[0].query == "AI" + assert article_service.list_article_items(subscription_id=subscription.id)[0].title == "AI 搜索结果" + + +def test_refresh_publisher_subscription_links_items(tmp_path, monkeypatch): + article_service = _load_article_service(tmp_path, monkeypatch) + service = article_service.ArticleService( + fetchers={"xiaohongshu": SearchFetcher()}, + gpt_factory=lambda *_: FakeGPT(), + ) + subscription = article_service.create_subscription("xiaohongshu", "publisher", "作者", "作者") + + result = service.refresh_subscription(subscription.id) + + assert result["count"] == 1 + assert result["items"][0]["title"] == "作者 发布者结果" diff --git a/backend/tests/test_browser_cookie_service.py b/backend/tests/test_browser_cookie_service.py new file mode 100644 index 0000000000000000000000000000000000000000..df87d37a80830c3d1e9f8604cc9920e2c22a78d6 --- /dev/null +++ b/backend/tests/test_browser_cookie_service.py @@ -0,0 +1,60 @@ +import sys +from pathlib import Path +from types import SimpleNamespace + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app.services.browser_cookie import BrowserCookieError, sync_browser_cookie +from app.services.cookie_manager import CookieConfigManager + + +def test_sync_browser_cookie_filters_platform_domain_and_persists(tmp_path, monkeypatch): + manager = CookieConfigManager(filepath=str(tmp_path / "downloader.json")) + + def fake_extract(browser): + assert browser == "chrome" + return [ + SimpleNamespace(domain=".youtube.com", name="SID", value="abc"), + SimpleNamespace(domain="www.youtube.com", name="HSID", value="def"), + SimpleNamespace(domain=".douyin.com", name="sessionid", value="skip"), + SimpleNamespace(domain=".youtube.com", name="empty", value=""), + ] + + monkeypatch.setattr("app.services.browser_cookie._extract_cookies_from_browser", fake_extract) + + result = sync_browser_cookie("youtube", "chrome", manager=manager) + + assert result == { + "platform": "youtube", + "browser": "chrome", + "cookie": "SID=abc; HSID=def", + "count": 2, + } + assert manager.get("youtube") == "SID=abc; HSID=def" + assert manager.get_browser("youtube") == "chrome" + + +def test_sync_browser_cookie_raises_when_browser_has_no_platform_cookie(tmp_path, monkeypatch): + manager = CookieConfigManager(filepath=str(tmp_path / "downloader.json")) + opened = [] + + monkeypatch.setattr( + "app.services.browser_cookie._extract_cookies_from_browser", + lambda browser: [SimpleNamespace(domain=".example.com", name="SID", value="abc")], + ) + monkeypatch.setattr( + "app.services.browser_cookie._open_url_in_browser", + lambda url, browser: opened.append((url, browser)) or True, + ) + + try: + sync_browser_cookie("bilibili", "safari", manager=manager) + except BrowserCookieError as exc: + assert "未找到 bilibili 对应的浏览器 Cookie" in str(exc) + assert "已打开B站页面" in str(exc) + else: + raise AssertionError("Expected BrowserCookieError") + + assert opened == [("https://www.bilibili.com/", "safari")] diff --git a/backend/tests/test_douyin_downloader.py b/backend/tests/test_douyin_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..322a0717ba461508edcf7da62f0281be3bcb5122 --- /dev/null +++ b/backend/tests/test_douyin_downloader.py @@ -0,0 +1,210 @@ +import json +import sys +from pathlib import Path + +import pytest + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app.downloaders import douyin_downloader +from app.downloaders.douyin_downloader import DouyinDownloader, DouyinResolveError + + +AWEME_ID = "7345492945006595379" +VIDEO_ID = "v0200fg10000cq123abc" + + +def _router_html(item: dict) -> str: + payload = {"loaderData": {"anything": {"item_list": [item]}}} + return ( + "" + ) + + +def _video_item() -> dict: + return { + "aweme_id": AWEME_ID, + "desc": "测试视频 #知识", + "duration": 123456, + "author": {"nickname": "作者A"}, + "video": { + "play_addr": { + "uri": VIDEO_ID, + "url_list": ["https://example.com/playwm/?video_id=watermarked"], + }, + "cover": {"url_list": ["https://example.com/cover.jpg"]}, + }, + "text_extra": [{"hashtag_name": "知识"}], + } + + +class DummyResponse: + def __init__( + self, + *, + text="", + content=b"", + url="https://www.iesdouyin.com/share/video/1/", + ): + self.text = text + self.content = content + self.url = url + self.status_code = 200 + + def raise_for_status(self): + return None + + def iter_content(self, chunk_size=8192): + yield self.content + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +def test_expand_share_url_extracts_douyin_url_from_share_text(): + share_text = ( + "7.43 复制打开抖音,看看这个视频 " + "https://v.douyin.com/abc123/ 复制此链接,打开Dou音搜索" + ) + + assert douyin_downloader.expand_share_url(share_text) == "https://v.douyin.com/abc123" + + +@pytest.mark.parametrize( + ("url", "expected"), + [ + ( + f"https://www.douyin.com/video/{AWEME_ID}", + f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + ), + ( + f"https://www.douyin.com/note/{AWEME_ID}", + f"https://www.iesdouyin.com/share/note/{AWEME_ID}/", + ), + ( + f"https://www.douyin.com/search/agent?aid=abc&modal_id={AWEME_ID}&type=general", + f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + ), + ( + f"https://www.douyin.com/search/%E7%9F%A5%E8%AF%86?item_ids={AWEME_ID}", + f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + ), + ], +) +def test_normalize_to_share_page_converts_www_urls(url, expected): + assert douyin_downloader.normalize_to_share_page(url) == expected + + +def test_parse_share_page_html_reads_router_data_video_metadata(): + html = _router_html(_video_item()) + + meta = douyin_downloader.parse_share_page_html( + html, + f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + "https://v.douyin.com/abc123", + ) + + assert meta.aweme_id == AWEME_ID + assert meta.title == "测试视频 #知识" + assert meta.author == "作者A" + assert meta.cover_url == "https://example.com/cover.jpg" + assert meta.duration == pytest.approx(123.456) + assert meta.download_url == ( + f"https://aweme.snssdk.com/aweme/v1/play/?video_id={VIDEO_ID}&ratio=720p&line=0" + ) + assert meta.tags == ["知识"] + + +def test_resolve_douyin_share_fetches_share_page_for_search_modal(monkeypatch): + html = _router_html(_video_item()) + requested_urls = [] + search_url = ( + "https://www.douyin.com/search/agent" + f"?aid=4848bd6d-24bb-480e-aec8-b3ba55799c17&modal_id={AWEME_ID}&type=general" + ) + + class DummySession: + headers = {} + + def get(self, url, allow_redirects=True, timeout=30): + requested_urls.append(url) + return DummyResponse( + text=html, + url=f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + ) + + monkeypatch.setattr(douyin_downloader, "_session", lambda: DummySession()) + + meta = douyin_downloader.resolve_douyin_share(search_url) + + assert requested_urls == [f"https://www.iesdouyin.com/share/video/{AWEME_ID}/"] + assert meta.aweme_id == AWEME_ID + assert meta.download_url == ( + f"https://aweme.snssdk.com/aweme/v1/play/?video_id={VIDEO_ID}&ratio=720p&line=0" + ) + + +def test_parse_share_page_html_raises_actionable_error_without_ssr_data(): + with pytest.raises(DouyinResolveError, match="分享页未找到"): + douyin_downloader.parse_share_page_html( + "", + f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + "https://v.douyin.com/abc123", + ) + + +def test_downloader_download_uses_share_page_video_and_extracts_audio( + monkeypatch, tmp_path: Path +): + html = _router_html(_video_item()) + + class DummySession: + headers = {} + + def get(self, url, allow_redirects=True, timeout=30): + assert url == "https://v.douyin.com/abc123" + return DummyResponse( + text=html, + url=f"https://www.iesdouyin.com/share/video/{AWEME_ID}/", + ) + + requested_urls = [] + + def fake_session(): + return DummySession() + + def fake_get(url, headers=None, stream=False, timeout=None, **kwargs): + requested_urls.append(url) + return DummyResponse(content=b"fake mp4") + + def fake_run(cmd, check, stdout, stderr): + output_path = Path(cmd[-1]) + output_path.write_bytes(b"fake mp3") + + monkeypatch.setattr(douyin_downloader, "_session", fake_session) + monkeypatch.setattr(douyin_downloader.requests, "get", fake_get) + monkeypatch.setattr(douyin_downloader.subprocess, "run", fake_run) + + result = DouyinDownloader().download( + "https://v.douyin.com/abc123/", + output_dir=str(tmp_path), + ) + + assert result.video_id == AWEME_ID + assert result.platform == "douyin" + assert result.title == "测试视频 #知识" + assert result.duration == pytest.approx(123.456) + assert result.cover_url == "https://example.com/cover.jpg" + assert result.raw_info["tags"] == ["知识"] + assert Path(result.file_path).read_bytes() == b"fake mp3" + assert Path(result.video_path).read_bytes() == b"fake mp4" + assert requested_urls == [ + f"https://aweme.snssdk.com/aweme/v1/play/?video_id={VIDEO_ID}&ratio=720p&line=0" + ] diff --git a/backend/tests/test_hot_videos_route.py b/backend/tests/test_hot_videos_route.py new file mode 100644 index 0000000000000000000000000000000000000000..b49b332240afb4da8df476e397f758c7439c7058 --- /dev/null +++ b/backend/tests/test_hot_videos_route.py @@ -0,0 +1,80 @@ +import sys +from contextlib import asynccontextmanager +from pathlib import Path + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app import create_app +from app.services.hot_videos import HotVideoItem, PlatformHotVideoResult + + +@asynccontextmanager +async def noop_lifespan(app: FastAPI): + yield + + +def test_hot_videos_route_returns_normalized_payload(monkeypatch): + from app.routers import hot_videos as route + + def fake_payload(platform="all", limit=12, force=False): + return { + "platform": platform, + "limit": limit, + "generated_at": "2026-06-08T09:30:00+08:00", + "platforms": [ + PlatformHotVideoResult( + platform="bilibili", + status="ok", + message="", + items=[ + HotVideoItem( + id="BV1route", + platform="bilibili", + title="路由测试", + url="https://www.bilibili.com/video/BV1route", + rank=1, + source="bilibili_popular", + ) + ], + ).to_dict() + ], + } + + monkeypatch.setattr(route, "fetch_hot_video_payload", fake_payload) + app = create_app(lifespan=noop_lifespan) + client = TestClient(app) + + response = client.get("/api/hot_videos?platform=bilibili&limit=3") + + assert response.status_code == 200 + body = response.json() + assert body["code"] == 0 + assert body["data"]["platform"] == "bilibili" + assert body["data"]["limit"] == 3 + assert ( + body["data"]["platforms"][0]["items"][0]["url"] + == "https://www.bilibili.com/video/BV1route" + ) + + +def test_hot_videos_route_returns_business_error_for_invalid_platform(monkeypatch): + from app.routers import hot_videos as route + + def fake_payload(platform="all", limit=12, force=False): + raise ValueError("不支持的热点平台: instagram") + + monkeypatch.setattr(route, "fetch_hot_video_payload", fake_payload) + app = create_app(lifespan=noop_lifespan) + client = TestClient(app) + + response = client.get("/api/hot_videos?platform=instagram") + + assert response.status_code == 200 + body = response.json() + assert body["code"] == 400 + assert "不支持的热点平台" in body["msg"] diff --git a/backend/tests/test_hot_videos_service.py b/backend/tests/test_hot_videos_service.py new file mode 100644 index 0000000000000000000000000000000000000000..fcea8ec98dca9f6bda267401afc235b67cd8c42f --- /dev/null +++ b/backend/tests/test_hot_videos_service.py @@ -0,0 +1,411 @@ +import sys +import threading +from pathlib import Path + +import pytest + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app.services import hot_videos +from app.services.hot_videos import ( + HotVideoItem, + PlatformHotVideoResult, + _format_bilibili_views, + _fetch_bilibili_hot, + _map_douyin_hot_items, + _map_bilibili_reader_markdown_items, + _map_bilibili_popular_items, + _map_youtube_trending_html, + fetch_hot_video_payload, + fetch_hot_videos, +) + + +def test_format_bilibili_views_uses_chinese_units(): + assert _format_bilibili_views(9999) == "9999播放" + assert _format_bilibili_views(10000) == "1.0万播放" + assert _format_bilibili_views(1250000) == "125.0万播放" + + +def test_map_bilibili_popular_items_normalizes_expected_fields(): + payload = { + "data": { + "list": [ + { + "bvid": "BV1abc123", + "title": "测试热门视频", + "pic": "//i0.hdslb.com/bfs/archive/cover.jpg", + "owner": {"name": "测试 UP"}, + "stat": {"view": 123456}, + }, + { + "bvid": "", + "title": "没有 bvid 的条目会被跳过", + }, + ] + } + } + + items = _map_bilibili_popular_items(payload, limit=5) + + assert items == [ + HotVideoItem( + id="BV1abc123", + platform="bilibili", + title="测试热门视频", + url="https://www.bilibili.com/video/BV1abc123", + cover_url="https://i0.hdslb.com/bfs/archive/cover.jpg", + author="测试 UP", + rank=1, + hot_score="12.3万播放", + source="bilibili_popular", + ) + ] + + +def test_map_bilibili_reader_markdown_items_extracts_popular_page_links(): + markdown = """ +[![Image 1](https://i1.hdslb.com/bfs/archive/cover.jpg@412w_232h_1c_!web-popular.avif)](https://www.bilibili.com/video/BV1xkE46PE59) + +参加了九次高考的考生采访 + +自来卷三木 + +275.2万 1847 +""" + + items = _map_bilibili_reader_markdown_items(markdown, limit=3) + + assert items == [ + HotVideoItem( + id="BV1xkE46PE59", + platform="bilibili", + title="参加了九次高考的考生采访", + url="https://www.bilibili.com/video/BV1xkE46PE59", + cover_url="https://i1.hdslb.com/bfs/archive/cover.jpg@412w_232h_1c_!web-popular.avif", + author="自来卷三木", + rank=1, + hot_score="275.2万播放", + source="bilibili_popular_reader", + ) + ] + + +def test_fetch_bilibili_hot_falls_back_to_reader_when_api_fails(monkeypatch): + markdown = """ +[![Image 1](https://i1.hdslb.com/bfs/archive/cover.jpg)](https://www.bilibili.com/video/BV1xkE46PE59) + +参加了九次高考的考生采访 + +自来卷三木 + +275.2万 1847 +""" + + class FakeResponse: + text = markdown + + def raise_for_status(self): + return None + + class FakeSession: + headers = {} + + def __init__(self): + self.calls = [] + + def get(self, url, **kwargs): + self.calls.append(url) + if "api.bilibili.com" in url: + raise RuntimeError("tls eof") + return FakeResponse() + + session = FakeSession() + monkeypatch.setattr(hot_videos, "_session", lambda: session) + + result = _fetch_bilibili_hot(limit=2) + + assert result.status == "ok" + assert result.message == "官方热点接口暂不可用,已切换备用热点源" + assert result.items[0].source == "bilibili_popular_reader" + assert result.items[0].url == "https://www.bilibili.com/video/BV1xkE46PE59" + assert any("api.bilibili.com" in call for call in session.calls) + assert any("r.jina.ai" in call for call in session.calls) + + +def test_fetch_bilibili_hot_uses_snapshot_when_live_sources_fail(monkeypatch): + class FakeSession: + headers = {} + + def get(self, url, **kwargs): + raise RuntimeError(f"blocked: {url}") + + monkeypatch.setattr(hot_videos, "_session", lambda: FakeSession()) + + result = _fetch_bilibili_hot(limit=2) + + assert result.status == "ok" + assert result.message == "实时热点源暂不可用,已显示最近热门快照" + assert len(result.items) == 2 + assert {item.platform for item in result.items} == {"bilibili"} + assert {item.source for item in result.items} == {"bilibili_popular_snapshot"} + assert all(item.url.startswith("https://www.bilibili.com/video/BV") for item in result.items) + + +def test_fetch_all_keeps_successful_platform_when_another_platform_fails(monkeypatch): + def fake_bilibili(limit): + return PlatformHotVideoResult( + platform="bilibili", + status="ok", + message="", + items=[ + HotVideoItem( + id="BV1ok", + platform="bilibili", + title="可用热门", + url="https://www.bilibili.com/video/BV1ok", + rank=1, + source="bilibili_popular", + ) + ], + ) + + def fake_youtube(limit): + raise RuntimeError("network blocked") + + monkeypatch.setattr( + hot_videos, + "HOT_FETCHERS", + { + "bilibili": fake_bilibili, + "youtube": fake_youtube, + }, + ) + hot_videos.clear_hot_video_cache() + + payload = fetch_hot_video_payload(platform="all", limit=3) + + assert payload["platform"] == "all" + assert [p["platform"] for p in payload["platforms"]] == ["bilibili", "youtube"] + assert payload["platforms"][0]["status"] == "ok" + assert payload["platforms"][0]["items"][0]["title"] == "可用热门" + assert payload["platforms"][1]["status"] == "error" + assert "network blocked" in payload["platforms"][1]["message"] + + +def test_fetch_all_runs_platform_fetchers_concurrently(monkeypatch): + lock = threading.Lock() + started: list[str] = [] + both_started = threading.Event() + + def make_fetcher(name: str): + def fake_fetcher(limit): + with lock: + started.append(name) + if len(started) == 2: + both_started.set() + if not both_started.wait(timeout=1): + raise AssertionError("platform fetchers did not overlap") + return PlatformHotVideoResult( + platform=name, + status="ok", + message="", + items=[ + HotVideoItem( + id=f"{name}-1", + platform=name, + title=f"{name} 热点", + url=f"https://example.com/{name}", + rank=1, + ) + ], + ) + + return fake_fetcher + + monkeypatch.setattr( + hot_videos, + "HOT_FETCHERS", + { + "bilibili": make_fetcher("bilibili"), + "youtube": make_fetcher("youtube"), + }, + ) + hot_videos.clear_hot_video_cache() + + payload = fetch_hot_video_payload(platform="all", limit=1, force=True) + + assert [item["status"] for item in payload["platforms"]] == ["ok", "ok"] + assert [item["platform"] for item in payload["platforms"]] == ["bilibili", "youtube"] + + +def test_fetch_hot_videos_rejects_unknown_platform(): + with pytest.raises(ValueError, match="不支持的热点平台"): + fetch_hot_videos(platform="instagram", limit=3) + + +def test_cache_returns_first_payload_inside_ttl(monkeypatch): + calls = [] + + def fake_bilibili(limit): + calls.append(limit) + return PlatformHotVideoResult( + platform="bilibili", + status="ok", + message="", + items=[ + HotVideoItem( + id=f"BV{len(calls)}", + platform="bilibili", + title=f"第 {len(calls)} 次", + url=f"https://www.bilibili.com/video/BV{len(calls)}", + rank=1, + source="bilibili_popular", + ) + ], + ) + + monkeypatch.setattr(hot_videos, "HOT_FETCHERS", {"bilibili": fake_bilibili}) + hot_videos.clear_hot_video_cache() + + first = fetch_hot_video_payload(platform="bilibili", limit=1) + second = fetch_hot_video_payload(platform="bilibili", limit=1) + + assert len(calls) == 1 + assert first == second + assert second["platforms"][0]["items"][0]["title"] == "第 1 次" + + +def test_map_youtube_trending_html_extracts_video_renderers(): + html = """ + + """ + + items = _map_youtube_trending_html(html, limit=5) + + assert items[0] == HotVideoItem( + id="abc123XYZ09", + platform="youtube", + title="YouTube 热门", + url="https://www.youtube.com/watch?v=abc123XYZ09", + cover_url="https://i.ytimg.com/vi/abc123XYZ09/hqdefault.jpg", + author="频道名", + rank=1, + hot_score="12万次观看", + source="youtube_trending", + ) + + +def test_map_douyin_hot_items_extracts_detail_aweme_entries(): + payload = { + "data": { + "word_list": [ + { + "word": "热点话题", + "hot_value": 123456, + "aweme_infos": [ + { + "aweme_id": "7123456789012345678", + "desc": "抖音热点视频", + "author": {"nickname": "创作者"}, + "video": {"cover": {"url_list": ["https://example.com/cover.jpg"]}}, + } + ], + } + ] + } + } + + items = _map_douyin_hot_items(payload, limit=5) + + assert items[0] == HotVideoItem( + id="7123456789012345678", + platform="douyin", + title="抖音热点视频", + url="https://www.douyin.com/video/7123456789012345678", + cover_url="https://example.com/cover.jpg", + author="创作者", + rank=1, + hot_score="123456热度", + source="douyin_hot_search", + ) + + +def test_fetch_newsnow_hot_maps_correctly(monkeypatch): + from app.services.hot_videos import _fetch_newsnow_hot + + mock_payload = { + "status": "cache", + "id": "zhihu", + "items": [ + { + "id": "12345", + "title": "测试知乎标题", + "url": "https://zhihu.com/question/12345", + "author": "知乎网友", + "extra": {"info": "100万热度"}, + }, + { + "id": "67890", + "title": "", # empty title should be skipped + "url": "https://zhihu.com/question/67890", + }, + ], + } + + class FakeResponse: + status_code = 200 + + def raise_for_status(self): + return None + + def json(self): + return mock_payload + + class FakeSession: + headers = {} + + def get(self, url, params=None, **kwargs): + assert "newsnow" in url + assert params == {"id": "zhihu"} + return FakeResponse() + + monkeypatch.setattr(hot_videos, "_session", lambda: FakeSession()) + + result = _fetch_newsnow_hot("zhihu", limit=5) + assert result.status == "ok" + assert result.platform == "zhihu" + assert len(result.items) == 1 + assert result.items[0] == HotVideoItem( + id="12345", + platform="zhihu", + title="测试知乎标题", + url="https://zhihu.com/question/12345", + cover_url="", + author="知乎网友", + rank=1, + hot_score="100万热度", + source="newsnow", + ) + + +def test_fetch_newsnow_hot_handles_failure(monkeypatch): + from app.services.hot_videos import _fetch_newsnow_hot + + class FakeSession: + headers = {} + + def get(self, url, params=None, **kwargs): + raise RuntimeError("API timeout") + + monkeypatch.setattr(hot_videos, "_session", lambda: FakeSession()) + + result = _fetch_newsnow_hot("zhihu", limit=5) + assert result.status == "error" + assert "API timeout" in result.message + assert len(result.items) == 0 + diff --git a/backend/tests/test_model_service.py b/backend/tests/test_model_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa80f586cade69354ee77bbb40aaed743b19b4f --- /dev/null +++ b/backend/tests/test_model_service.py @@ -0,0 +1,57 @@ +import sys +from pathlib import Path +from types import SimpleNamespace + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app.services.model import ModelService +from app.services.provider import ProviderService + + +def _provider(provider_id: str): + return { + "id": provider_id, + "name": "DeepSeek", + "api_key": "sk-test", + "base_url": "https://api.deepseek.com", + } + + +def test_get_all_models_by_id_accepts_plain_model_list(monkeypatch): + monkeypatch.setattr(ProviderService, "get_provider_by_id", _provider) + monkeypatch.setattr( + ModelService, + "get_model_list", + lambda provider_id, verbose=False: [ + {"id": "deepseek-chat", "object": "model"}, + {"id": "deepseek-reasoner", "object": "model"}, + ], + ) + + result = ModelService.get_all_models_by_id("deepseek") + + assert result == { + "models": [ + {"id": "deepseek-chat", "object": "model"}, + {"id": "deepseek-reasoner", "object": "model"}, + ] + } + + +def test_get_all_models_by_id_accepts_openai_page_data(monkeypatch): + monkeypatch.setattr(ProviderService, "get_provider_by_id", _provider) + monkeypatch.setattr( + ModelService, + "get_model_list", + lambda provider_id, verbose=False: SimpleNamespace( + data=[ + SimpleNamespace(model_dump=lambda: {"id": "gpt-4o-mini", "object": "model"}), + ] + ), + ) + + result = ModelService.get_all_models_by_id("openai") + + assert result == {"models": [{"id": "gpt-4o-mini", "object": "model"}]} diff --git a/backend/tests/test_note_helper.py b/backend/tests/test_note_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1dbd282cf78bf15e44ace00191eab27e9e1a4c9c --- /dev/null +++ b/backend/tests/test_note_helper.py @@ -0,0 +1,79 @@ +import importlib.util +import pathlib +import unittest + + +ROOT = pathlib.Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "app" / "utils" / "note_helper.py" +spec = importlib.util.spec_from_file_location("note_helper", MODULE_PATH) +if spec is None or spec.loader is None: + raise ImportError("note_helper module spec not found") +note_helper = importlib.util.module_from_spec(spec) +spec.loader.exec_module(note_helper) + + +class TestNoteHelper(unittest.TestCase): + def test_prepend_source_link_adds_header_at_top(self): + source_url = "https://www.bilibili.com/video/BV1xx411c7mD" + markdown = "## 标题\n\n内容" + + result = note_helper.prepend_source_link(markdown, source_url) + + self.assertTrue(result.startswith(f"> 来源链接:{source_url}\n\n")) + self.assertIn("## 标题", result) + + def test_prepend_source_link_does_not_duplicate_when_header_exists(self): + source_url = "https://www.youtube.com/watch?v=abc123" + markdown = f"> 来源链接:{source_url}\n\n## 标题\n\n内容" + + result = note_helper.prepend_source_link(markdown, source_url) + + self.assertEqual(result, markdown) + + def test_normalize_toc_strips_heading_markers_in_items(self): + markdown = "## 目录\n\n- ## 1. 章节一\n- ## 2. 章节二\n\n## 1. 章节一\n正文" + + result = note_helper.normalize_toc(markdown) + + self.assertIn("- 1. 章节一", result) + self.assertIn("- 2. 章节二", result) + self.assertNotIn("- ## ", result) + # 正文标题不受影响 + self.assertIn("\n## 1. 章节一\n", result) + + def test_normalize_toc_strips_heading_marker_inside_bold(self): + markdown = "## 目录\n\n- **## 4. 应用矩阵**\n\n## 4. 应用矩阵\n正文" + + result = note_helper.normalize_toc(markdown) + + # 加粗保留,只剥标题标记 + self.assertIn("- **4. 应用矩阵**", result) + + def test_normalize_toc_keeps_sub_items_and_strips_their_markers(self): + markdown = ( + "## 目录\n\n" + "- 章节一\n" + " - 子项A\n" + " - ## 子项B\n" + "- 章节二\n\n" + "## 章节一\n正文" + ) + + result = note_helper.normalize_toc(markdown) + + # 嵌套子项允许、缩进保留;子项里的标题标记同样剥掉 + self.assertIn(" - 子项A", result) + self.assertIn(" - 子项B", result) + self.assertNotIn("- ## 子项B", result) + self.assertIn("- 章节一", result) + self.assertIn("- 章节二", result) + + def test_normalize_toc_noop_without_toc_section(self): + markdown = "# 标题\n\n- 普通列表 ## 不该被动\n正文" + + self.assertEqual(note_helper.normalize_toc(markdown), markdown) + self.assertIsNone(note_helper.normalize_toc(None)) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_openai_client.py b/backend/tests/test_openai_client.py new file mode 100644 index 0000000000000000000000000000000000000000..b3314b1ffbdba2d65ce5b9a126db08c18e9f6965 --- /dev/null +++ b/backend/tests/test_openai_client.py @@ -0,0 +1,27 @@ +import sys +from pathlib import Path + +BACKEND_ROOT = Path(__file__).resolve().parents[1] +if str(BACKEND_ROOT) not in sys.path: + sys.path.insert(0, str(BACKEND_ROOT)) + +from app.services.proxy_config_manager import ProxyConfigManager +from app.utils.openai_client import build_openai_client + + +def test_build_openai_client_ignores_invalid_no_proxy_env(monkeypatch): + monkeypatch.setenv( + "NO_PROXY", + "127.0.0.1,localhost,::1,127.0.0.0/8,::1/128", + ) + monkeypatch.setenv( + "no_proxy", + "127.0.0.1,localhost,::1,127.0.0.0/8,::1/128", + ) + for key in ("HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy", "ALL_PROXY", "all_proxy"): + monkeypatch.delenv(key, raising=False) + monkeypatch.setattr(ProxyConfigManager, "get_proxy_url", lambda self: None) + + client = build_openai_client("sk-test", "https://api.deepseek.com") + + assert str(client.base_url).rstrip("/") == "https://api.deepseek.com" diff --git a/backend/tests/test_request_chunker.py b/backend/tests/test_request_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..f75732c2ae81049182e571898143af7d88225b46 --- /dev/null +++ b/backend/tests/test_request_chunker.py @@ -0,0 +1,97 @@ +import importlib.util +import pathlib +import unittest +from dataclasses import dataclass + +ROOT = pathlib.Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "app" / "gpt" / "request_chunker.py" +spec = importlib.util.spec_from_file_location("request_chunker", MODULE_PATH) +if spec is None or spec.loader is None: + raise ImportError("request_chunker module spec not found") +request_chunker = importlib.util.module_from_spec(spec) +spec.loader.exec_module(request_chunker) +RequestChunker = request_chunker.RequestChunker + + +@dataclass +class DummySeg: + start: float + end: float + text: str + + +def build_messages(segments, image_urls, **_): + content = [{"type": "text", "text": "".join(s.text for s in segments)}] + for url in image_urls: + content.append({"type": "image_url", "image_url": {"url": url, "detail": "auto"}}) + return [{"role": "user", "content": content}] + + +def size_estimator(messages): + size = 0 + for part in messages[0]["content"]: + if part["type"] == "text": + size += len(part["text"]) + else: + size += len(part["image_url"]["url"]) + return size + + +class TestRequestChunker(unittest.TestCase): + def test_chunk_segments_preserves_order_and_content(self): + segments = [ + DummySeg(0, 1, "aaaa"), + DummySeg(1, 2, "bbbb"), + DummySeg(2, 3, "cccc"), + ] + chunker = RequestChunker(build_messages, max_bytes=8, size_estimator=size_estimator) + chunks = chunker.chunk(segments, []) + texts = ["".join(seg.text for seg in c.segments) for c in chunks] + self.assertEqual("".join(texts), "aaaabbbbcccc") + self.assertTrue(all(texts)) + + def test_chunk_images_distributed_across_batches(self): + segments = [DummySeg(0, 1, "aa")] + images = ["i" * 6, "j" * 6, "k" * 6] + chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) + chunks = chunker.chunk(segments, images) + all_images = [img for c in chunks for img in c.image_urls] + self.assertEqual(all_images, images) + + def test_chunk_images_are_not_front_loaded_when_multiple_segment_chunks(self): + segments = [ + DummySeg(0, 1, "aaaaaa"), + DummySeg(1, 2, "bbbbbb"), + DummySeg(2, 3, "cccccc"), + ] + images = ["11111", "22222", "33333"] + chunker = RequestChunker(build_messages, max_bytes=12, size_estimator=size_estimator) + chunks = chunker.chunk(segments, images) + + self.assertGreaterEqual(len(chunks), 3) + image_counts = [len(c.image_urls) for c in chunks] + self.assertGreater(image_counts[1], 0) + self.assertGreater(image_counts[2], 0) + all_images = [img for c in chunks for img in c.image_urls] + self.assertEqual(all_images, images) + + def test_split_oversized_segment(self): + segments = [DummySeg(0, 1, "x" * 25)] + chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) + chunks = chunker.chunk(segments, []) + combined = "".join(seg.text for c in chunks for seg in c.segments) + self.assertEqual(combined, "x" * 25) + + def test_group_texts_by_budget(self): + chunker = RequestChunker(build_messages, max_bytes=10, size_estimator=size_estimator) + + def build_text_messages(texts, *_args, **_kwargs): + content = [{"type": "text", "text": "".join(texts)}] + return [{"role": "user", "content": content}] + + groups = chunker.group_texts_by_budget(["aaaaa", "bbbbb", "ccccc"], build_text_messages) + self.assertEqual(groups, [["aaaaa", "bbbbb"], ["ccccc"]]) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_screenshot_marker.py b/backend/tests/test_screenshot_marker.py new file mode 100644 index 0000000000000000000000000000000000000000..b4b8a2134d112ac8ea7df37569819133cb7271e8 --- /dev/null +++ b/backend/tests/test_screenshot_marker.py @@ -0,0 +1,49 @@ +import importlib.util +import pathlib +import unittest + + +ROOT = pathlib.Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "app" / "utils" / "screenshot_marker.py" +spec = importlib.util.spec_from_file_location("screenshot_marker", MODULE_PATH) +if spec is None or spec.loader is None: + raise ImportError("screenshot_marker module spec not found") +screenshot_marker = importlib.util.module_from_spec(spec) +spec.loader.exec_module(screenshot_marker) +extract_screenshot_timestamps = screenshot_marker.extract_screenshot_timestamps +extract_content_timestamps = screenshot_marker.extract_content_timestamps +ensure_screenshot_markers = screenshot_marker.ensure_screenshot_markers + + +class TestScreenshotMarker(unittest.TestCase): + def test_extract_accepts_star_bracket_format(self): + markdown = "A\n*Screenshot-[01:02]\nB" + matches = extract_screenshot_timestamps(markdown) + self.assertEqual(matches, [("*Screenshot-[01:02]", 62)]) + + def test_extract_accepts_legacy_formats(self): + markdown = "*Screenshot-03:04 and Screenshot-[05:06]" + matches = extract_screenshot_timestamps(markdown) + self.assertEqual( + matches, + [ + ("*Screenshot-03:04", 184), + ("Screenshot-[05:06]", 306), + ], + ) + + def test_extract_content_timestamps_for_fallback(self): + markdown = "## A *Content-[00:12]\n## B *Content-[01:03]\n## C *Content-[01:03]" + matches = extract_content_timestamps(markdown) + self.assertEqual(matches, [12, 63]) + + def test_ensure_screenshot_markers_adds_duration_fallback(self): + markdown = "## A\ncontent" + with_markers = ensure_screenshot_markers(markdown, 120) + self.assertIn("*Screenshot-[00:30]", with_markers) + self.assertIn("*Screenshot-[01:00]", with_markers) + self.assertIn("*Screenshot-[01:30]", with_markers) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_strip_think_blocks.py b/backend/tests/test_strip_think_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0160d5f5a572a24f36a9897ee49a248b035c6931 --- /dev/null +++ b/backend/tests/test_strip_think_blocks.py @@ -0,0 +1,44 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from app.gpt.utils import strip_think_blocks + + +def test_paired_think_block_removed(): + text = "让我想想这个视频讲了什么…\n# 笔记标题\n\n正文内容" + assert strip_think_blocks(text) == "# 笔记标题\n\n正文内容" + + +def test_multiple_blocks_and_thinking_variant(): + text = "step 1开头step 2结尾" + assert strip_think_blocks(text) == "开头结尾" + + +def test_multiline_block_removed(): + text = "\n第一行\n第二行\n\n## 章节 *Content-[01:23]\n- 要点" + assert strip_think_blocks(text) == "## 章节 *Content-[01:23]\n- 要点" + + +def test_orphan_close_tag_keeps_tail(): + # 部分网关吞掉起始 ,正文只剩孤立的 + text = "嗯,用户想要笔记,我先分析转录……\n
\n# 真正的笔记" + assert strip_think_blocks(text) == "# 真正的笔记" + + +def test_unclosed_open_tag_drops_tail(): + # 输出被截断,只有起始标签 + text = "# 完整笔记\n\n正文\n这段推理被截断了" + assert strip_think_blocks(text) == "# 完整笔记\n\n正文" + + +def test_plain_text_untouched(): + text = "# 普通笔记\n\n没有任何标签,但提到了 think 这个词。" + assert strip_think_blocks(text) == text + + +def test_none_and_empty(): + assert strip_think_blocks(None) == "" + assert strip_think_blocks("") == "" + assert strip_think_blocks(" \n ") == "" diff --git a/backend/tests/test_sys_check_route.py b/backend/tests/test_sys_check_route.py new file mode 100644 index 0000000000000000000000000000000000000000..dd60b1edc0e265e6b612b99696f2b59cc928c42c --- /dev/null +++ b/backend/tests/test_sys_check_route.py @@ -0,0 +1,31 @@ +from contextlib import asynccontextmanager + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app import create_app + + +@asynccontextmanager +async def noop_lifespan(app: FastAPI): + yield + + +def test_sys_check_available_at_root_for_frontend_probe(): + app = create_app(lifespan=noop_lifespan) + client = TestClient(app) + + response = client.get("/sys_check") + + assert response.status_code == 200 + assert response.json()["code"] == 0 + + +def test_sys_check_still_available_under_api_prefix(): + app = create_app(lifespan=noop_lifespan) + client = TestClient(app) + + response = client.get("/api/sys_check") + + assert response.status_code == 200 + assert response.json()["code"] == 0 diff --git a/backend/tests/test_task_serial_executor.py b/backend/tests/test_task_serial_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..14e0e238464f8053aa074de156954f90ca9d02fa --- /dev/null +++ b/backend/tests/test_task_serial_executor.py @@ -0,0 +1,42 @@ +import importlib.util +import pathlib +import threading +import time +import unittest + + +ROOT = pathlib.Path(__file__).resolve().parents[1] +MODULE_PATH = ROOT / "app" / "services" / "task_serial_executor.py" +spec = importlib.util.spec_from_file_location("task_serial_executor", MODULE_PATH) +if spec is None or spec.loader is None: + raise ImportError("task_serial_executor module spec not found") +task_serial_executor = importlib.util.module_from_spec(spec) +spec.loader.exec_module(task_serial_executor) +SerialTaskExecutor = task_serial_executor.SerialTaskExecutor + + +class TestTaskSerialExecutor(unittest.TestCase): + def test_executor_runs_tasks_one_by_one(self): + executor = SerialTaskExecutor() + state_lock = threading.Lock() + state = {"active": 0, "peak_active": 0} + + def critical_work(): + with state_lock: + state["active"] += 1 + state["peak_active"] = max(state["peak_active"], state["active"]) + time.sleep(0.05) + with state_lock: + state["active"] -= 1 + + threads = [threading.Thread(target=lambda: executor.run(critical_work)) for _ in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(state["peak_active"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_universal_gpt_checkpoint.py b/backend/tests/test_universal_gpt_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f4bf7f8afcb3584d025128014af933e3fe69fd --- /dev/null +++ b/backend/tests/test_universal_gpt_checkpoint.py @@ -0,0 +1,148 @@ +import importlib.util +import json +import os +import pathlib +import sys +import tempfile +import types +import unittest +from pathlib import Path + + +def _install_stubs(): + app_mod = types.ModuleType("app") + gpt_pkg = types.ModuleType("app.gpt") + models_pkg = types.ModuleType("app.models") + + base_mod = types.ModuleType("app.gpt.base") + + class _GPT: + pass + + base_mod.GPT = _GPT + + prompt_builder_mod = types.ModuleType("app.gpt.prompt_builder") + + def _generate_base_prompt(**_kwargs): + return "prompt" + + prompt_builder_mod.generate_base_prompt = _generate_base_prompt + + prompt_mod = types.ModuleType("app.gpt.prompt") + prompt_mod.BASE_PROMPT = "" + prompt_mod.AI_SUM = "" + prompt_mod.SCREENSHOT = "" + prompt_mod.LINK = "" + prompt_mod.MERGE_PROMPT = "merge" + + utils_mod = types.ModuleType("app.gpt.utils") + + def _fix_markdown(text): + return text + + utils_mod.fix_markdown = _fix_markdown + utils_mod.strip_think_blocks = lambda text: (text or "").strip() + + request_chunker_mod = types.ModuleType("app.gpt.request_chunker") + + class _RequestChunker: + def __init__(self, *_args, **_kwargs): + pass + + def group_texts_by_budget(self, texts, _builder, **_kwargs): + return [texts] + + request_chunker_mod.RequestChunker = _RequestChunker + + gpt_model_mod = types.ModuleType("app.models.gpt_model") + + class _GPTSource: + pass + + gpt_model_mod.GPTSource = _GPTSource + + transcriber_model_mod = types.ModuleType("app.models.transcriber_model") + + class _TranscriptSegment: + def __init__(self, **kwargs): + self.start = kwargs.get("start", 0) + self.end = kwargs.get("end", 0) + self.text = kwargs.get("text", "") + + transcriber_model_mod.TranscriptSegment = _TranscriptSegment + + sys.modules.setdefault("app", app_mod) + sys.modules.setdefault("app.gpt", gpt_pkg) + sys.modules.setdefault("app.models", models_pkg) + sys.modules["app.gpt.base"] = base_mod + sys.modules["app.gpt.prompt_builder"] = prompt_builder_mod + sys.modules["app.gpt.prompt"] = prompt_mod + sys.modules["app.gpt.utils"] = utils_mod + sys.modules["app.gpt.request_chunker"] = request_chunker_mod + sys.modules["app.models.gpt_model"] = gpt_model_mod + sys.modules["app.models.transcriber_model"] = transcriber_model_mod + + +def _load_universal_gpt_class(): + _install_stubs() + root = pathlib.Path(__file__).resolve().parents[1] + module_path = root / "app" / "gpt" / "universal_gpt.py" + spec = importlib.util.spec_from_file_location("universal_gpt", module_path) + if spec is None or spec.loader is None: + raise ImportError("universal_gpt module spec not found") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.UniversalGPT + + +UniversalGPT = _load_universal_gpt_class() + + +class _FailingCompletions: + def create(self, **_kwargs): + raise Exception("Error code: 524 - bad_response_status_code") + + +class _DummyChat: + def __init__(self): + self.completions = _FailingCompletions() + + +class _DummyModels: + @staticmethod + def list(): + return [] + + +class _DummyClient: + def __init__(self): + self.chat = _DummyChat() + self.models = _DummyModels() + + +class TestUniversalGPTCheckpoint(unittest.TestCase): + def test_merge_524_error_persists_checkpoint(self): + original_attempts = os.environ.get("OPENAI_RETRY_ATTEMPTS") + os.environ["OPENAI_RETRY_ATTEMPTS"] = "1" + gpt = UniversalGPT(_DummyClient(), model="mock-model") + try: + with tempfile.TemporaryDirectory() as tmp_dir: + gpt.checkpoint_dir = Path(tmp_dir) + + with self.assertRaises(Exception): + gpt._merge_partials(["part-a", "part-b"], "task-1", "sig-1") + + checkpoint_path = gpt._checkpoint_path("task-1") + self.assertTrue(checkpoint_path.exists()) + payload = json.loads(checkpoint_path.read_text(encoding="utf-8")) + self.assertEqual(payload["phase"], "merge") + self.assertEqual(payload["partials"], ["part-a", "part-b"]) + finally: + if original_attempts is None: + os.environ.pop("OPENAI_RETRY_ATTEMPTS", None) + else: + os.environ["OPENAI_RETRY_ATTEMPTS"] = original_attempts + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_universal_gpt_content_format.py b/backend/tests/test_universal_gpt_content_format.py new file mode 100644 index 0000000000000000000000000000000000000000..869066d464438e0a516494095b2719d242ae2422 --- /dev/null +++ b/backend/tests/test_universal_gpt_content_format.py @@ -0,0 +1,192 @@ +"""issue #282 回归测试:UniversalGPT 拼装 content 时按是否有图片切换 string / array 形态。 + +DeepSeek deepseek-chat 等非多模态模型只接受 ``content`` 为字符串,旧实现无条件 +emit ``[{"type":"text","text":...}]`` 导致 ``invalid_request_error``。 +""" +import importlib.util +import pathlib +import sys +import types +import unittest + + +def _install_stubs(): + app_mod = types.ModuleType("app") + gpt_pkg = types.ModuleType("app.gpt") + models_pkg = types.ModuleType("app.models") + + base_mod = types.ModuleType("app.gpt.base") + + class _GPT: + pass + + base_mod.GPT = _GPT + + prompt_builder_mod = types.ModuleType("app.gpt.prompt_builder") + + def _generate_base_prompt(**_kwargs): + return "PROMPT_BODY" + + prompt_builder_mod.generate_base_prompt = _generate_base_prompt + + prompt_mod = types.ModuleType("app.gpt.prompt") + prompt_mod.BASE_PROMPT = "" + prompt_mod.AI_SUM = "" + prompt_mod.SCREENSHOT = "" + prompt_mod.LINK = "" + prompt_mod.MERGE_PROMPT = "MERGE_HEAD" + + utils_mod = types.ModuleType("app.gpt.utils") + + def _fix_markdown(text): + return text + + utils_mod.fix_markdown = _fix_markdown + utils_mod.strip_think_blocks = lambda text: (text or "").strip() + + request_chunker_mod = types.ModuleType("app.gpt.request_chunker") + + class _RequestChunker: + def __init__(self, *_args, **_kwargs): + pass + + def group_texts_by_budget(self, texts, _builder, **_kwargs): + return [texts] + + request_chunker_mod.RequestChunker = _RequestChunker + + gpt_model_mod = types.ModuleType("app.models.gpt_model") + + class _GPTSource: + pass + + gpt_model_mod.GPTSource = _GPTSource + + transcriber_model_mod = types.ModuleType("app.models.transcriber_model") + + class _TranscriptSegment: + def __init__(self, **kwargs): + self.start = kwargs.get("start", 0) + self.end = kwargs.get("end", 0) + self.text = kwargs.get("text", "") + + transcriber_model_mod.TranscriptSegment = _TranscriptSegment + + sys.modules.setdefault("app", app_mod) + sys.modules.setdefault("app.gpt", gpt_pkg) + sys.modules.setdefault("app.models", models_pkg) + sys.modules["app.gpt.base"] = base_mod + sys.modules["app.gpt.prompt_builder"] = prompt_builder_mod + sys.modules["app.gpt.prompt"] = prompt_mod + sys.modules["app.gpt.utils"] = utils_mod + sys.modules["app.gpt.request_chunker"] = request_chunker_mod + sys.modules["app.models.gpt_model"] = gpt_model_mod + sys.modules["app.models.transcriber_model"] = transcriber_model_mod + + +def _load_universal_gpt_class(): + _install_stubs() + root = pathlib.Path(__file__).resolve().parents[1] + module_path = root / "app" / "gpt" / "universal_gpt.py" + spec = importlib.util.spec_from_file_location( + "universal_gpt_content_format", module_path + ) + if spec is None or spec.loader is None: + raise ImportError("universal_gpt module spec not found") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.UniversalGPT + + +UniversalGPT = _load_universal_gpt_class() + + +class _DummyClient: + """create_messages 不会真的调用 client,给个空壳即可。""" + + +def _make_gpt(): + return UniversalGPT(_DummyClient(), model="deepseek-chat") + + +class TestCreateMessagesContentFormat(unittest.TestCase): + """覆盖 create_messages 在不同 video_img_urls 输入下的输出形态。""" + + def test_no_images_emits_string_content(self): + """无图片时 content 为 str(DeepSeek / 非多模态模型可解析)。""" + gpt = _make_gpt() + + messages = gpt.create_messages(segments=[]) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["role"], "user") + self.assertIsInstance(messages[0]["content"], str) + self.assertEqual(messages[0]["content"], "PROMPT_BODY") + + def test_empty_image_list_emits_string_content(self): + """显式传入空列表也要走纯文本分支,避免图片字段误触发。""" + gpt = _make_gpt() + + messages = gpt.create_messages(segments=[], video_img_urls=[]) + + self.assertIsInstance(messages[0]["content"], str) + + def test_with_images_emits_multimodal_array(self): + """有图片时保留多模态 array 形态,确保多模态模型功能不退化。""" + gpt = _make_gpt() + + messages = gpt.create_messages( + segments=[], + video_img_urls=["https://example.com/a.jpg", "https://example.com/b.jpg"], + ) + + content = messages[0]["content"] + self.assertIsInstance(content, list) + self.assertEqual(len(content), 3) # 1 text + 2 images + self.assertEqual(content[0], {"type": "text", "text": "PROMPT_BODY"}) + self.assertEqual(content[1]["type"], "image_url") + self.assertEqual(content[1]["image_url"]["url"], "https://example.com/a.jpg") + # 不应携带 detail 字段:MiniMax 等兼容接口对 detail:"auto" 报 400 (2013), + # OpenAI 缺省值本来就是 auto + self.assertNotIn("detail", content[1]["image_url"]) + self.assertEqual(content[2]["image_url"]["url"], "https://example.com/b.jpg") + + def test_no_image_url_field_when_no_images(self): + """纯文本响应里不应该出现 image_url 关键字 —— 这是触发 DeepSeek 400 的根因。""" + gpt = _make_gpt() + + messages = gpt.create_messages(segments=[]) + + import json + serialized = json.dumps(messages, ensure_ascii=False) + self.assertNotIn("image_url", serialized) + + +class TestBuildMergeMessagesContentFormat(unittest.TestCase): + """合并阶段从不带图片,应该统一走 string content 路径。""" + + def test_merge_messages_use_string_content(self): + """否则长视频 chunk 后的合并阶段还会复现 issue #282 错误。""" + gpt = _make_gpt() + + messages = gpt._build_merge_messages(["partial-A", "partial-B"]) + + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0]["role"], "user") + self.assertIsInstance(messages[0]["content"], str) + self.assertIn("MERGE_HEAD", messages[0]["content"]) + self.assertIn("partial-A", messages[0]["content"]) + self.assertIn("partial-B", messages[0]["content"]) + + def test_merge_messages_no_image_url_field(self): + gpt = _make_gpt() + + messages = gpt._build_merge_messages(["x"]) + + import json + serialized = json.dumps(messages, ensure_ascii=False) + self.assertNotIn("image_url", serialized) + + +if __name__ == "__main__": + unittest.main() diff --git a/backend/tests/test_video_reader_dedupe.py b/backend/tests/test_video_reader_dedupe.py new file mode 100644 index 0000000000000000000000000000000000000000..25d4e044cbf1b7d7030da21a73984f917a71a2ab --- /dev/null +++ b/backend/tests/test_video_reader_dedupe.py @@ -0,0 +1,142 @@ +import importlib.util +import pathlib +import re +import sys +import tempfile +import types +import unittest +from unittest.mock import patch + + +def _install_stubs(): + app_mod = types.ModuleType("app") + utils_pkg = types.ModuleType("app.utils") + + logger_mod = types.ModuleType("app.utils.logger") + + class _Logger: + @staticmethod + def info(*_args, **_kwargs): + return None + + @staticmethod + def warning(*_args, **_kwargs): + return None + + @staticmethod + def error(*_args, **_kwargs): + return None + + def _get_logger(_name): + return _Logger() + + logger_mod.get_logger = _get_logger + + path_helper_mod = types.ModuleType("app.utils.path_helper") + ffmpeg_mod = types.ModuleType("ffmpeg") + + pil_mod = types.ModuleType("PIL") + pil_image_mod = types.ModuleType("PIL.Image") + pil_draw_mod = types.ModuleType("PIL.ImageDraw") + pil_font_mod = types.ModuleType("PIL.ImageFont") + + class _FakeImage: + pass + + class _FakeImageDraw: + @staticmethod + def Draw(*_args, **_kwargs): + return None + + class _FakeImageFont: + @staticmethod + def truetype(*_args, **_kwargs): + return None + + @staticmethod + def load_default(): + return None + + pil_image_mod.Image = _FakeImage + pil_draw_mod.ImageDraw = _FakeImageDraw + pil_font_mod.ImageFont = _FakeImageFont + + def _get_app_dir(name): + return name + + path_helper_mod.get_app_dir = _get_app_dir + ffmpeg_mod.probe = lambda *_args, **_kwargs: {"format": {"duration": "0"}} + + sys.modules.setdefault("app", app_mod) + sys.modules.setdefault("app.utils", utils_pkg) + sys.modules["PIL"] = pil_mod + sys.modules["PIL.Image"] = pil_image_mod + sys.modules["PIL.ImageDraw"] = pil_draw_mod + sys.modules["PIL.ImageFont"] = pil_font_mod + sys.modules["ffmpeg"] = ffmpeg_mod + sys.modules["app.utils.logger"] = logger_mod + sys.modules["app.utils.path_helper"] = path_helper_mod + + +def _load_video_reader_module(): + _install_stubs() + root = pathlib.Path(__file__).resolve().parents[1] + module_path = root / "app" / "utils" / "video_reader.py" + spec = importlib.util.spec_from_file_location("video_reader", module_path) + if spec is None or spec.loader is None: + raise ImportError("video_reader module spec not found") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +video_reader_module = _load_video_reader_module() +VideoReader = video_reader_module.VideoReader + + +def _make_fake_ffmpeg_runner(colors_by_second): + def _runner(cmd, check=True): + output_path = next((arg for arg in cmd if isinstance(arg, str) and arg.endswith(".jpg")), None) + if output_path is None: + raise AssertionError("Output path not found in ffmpeg cmd") + match = re.search(r"frame_(\d{2})_(\d{2})\.jpg$", output_path) + if match is None: + raise AssertionError("Unexpected output path") + sec = int(match.group(1)) * 60 + int(match.group(2)) + payload = colors_by_second[sec] + with open(output_path, "wb") as f: + f.write(payload) + return 0 + + return _runner + + +class TestVideoReaderDeduplicateFrames(unittest.TestCase): + def test_extract_frames_skips_adjacent_duplicates_when_enabled(self): + with tempfile.TemporaryDirectory() as tmp_dir: + frame_dir = pathlib.Path(tmp_dir) / "frames" + grid_dir = pathlib.Path(tmp_dir) / "grids" + reader = VideoReader( + video_path="dummy.mp4", + frame_interval=1, + frame_dir=str(frame_dir), + grid_dir=str(grid_dir), + ) + + fake_colors = { + 0: b"frame-a", + 1: b"frame-a", + 2: b"frame-b", + 3: b"frame-b", + } + + with patch.object(video_reader_module.ffmpeg, "probe", return_value={"format": {"duration": "4"}}), \ + patch.object(video_reader_module.subprocess, "run", side_effect=_make_fake_ffmpeg_runner(fake_colors)): + paths = reader.extract_frames(max_frames=10) + + names = [pathlib.Path(p).name for p in paths] + self.assertEqual(names, ["frame_00_00.jpg", "frame_00_02.jpg"]) + + +if __name__ == "__main__": + unittest.main()