Spaces:
Running
Running
| """ | |
| OpenWolf HF Spaces — FastAPI 入口 | |
| ML 依赖在 startup 时自动安装,保持 Docker 构建轻量 | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import asyncio | |
| import time | |
| import threading | |
| import uuid | |
| import requests | |
| import re | |
| import hashlib | |
| import random | |
| from pathlib import Path | |
| from fastapi import FastAPI, Request, HTTPException, BackgroundTasks | |
| import concurrent.futures | |
| from fastapi.responses import JSONResponse | |
# Put /app at the front of the import search path.
sys.path.insert(0, "/app")

# ── Default environment variables ──
# setdefault() never overrides a value that is already present; the two token
# variables fall back to GITHUB_PAT (or "" when that is unset as well).
for _env_name, _env_default in (
    ("ISSUE_NUMBER", "0"),
    ("COMMENT_BODY", ""),
    ("COMMENT_USER", "spaces"),
    ("GITHUB_REPO", "hughyonng/OpenWolf"),
    ("GITHUB_TOKEN", os.environ.get("GITHUB_PAT", "")),
    ("OPENWOLF_PAT", os.environ.get("GITHUB_PAT", "")),
    ("TELEGRAM_BOT_TOKEN", ""),
    ("TELEGRAM_CHAT_ID", ""),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default

app = FastAPI(title="OpenWolf Agent with Cloud Acceleration")
async def _catch_all(request: Request, exc: Exception):
    """Log an exception with its request line and answer with a JSON 500.

    NOTE(review): no handler registration is visible in this view of the
    file — confirm how this function is wired into the app.
    """
    request_line = f"{request.method} {request.url.path}"
    print(f"[FATAL] {request_line}: {exc}")
    payload = {"ok": False, "error": str(exc)}
    return JSONResponse(payload, status_code=500)
# ── Service / embedding-model state flags ──
_ready = False
_model_loaded = False
_model_loading = False
_background_tasks = set()

# ── GGUF model handle and module-level locks ──
_llama_model = None
_llama_lock = threading.Lock()
_infer_lock = threading.Lock()

# ── Worker pools ──
_extract_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1, thread_name_prefix="extract")
_translate_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="translate")
_analyze_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="analyze")
_task_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="text_task")
_ocr_pdf_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2, thread_name_prefix="ocr_pdf")

# ── In-memory task registries (task id -> task record) ──
_translate_tasks = {}
_analyze_tasks = {}
_task_tasks = {}
_ocr_tasks = {}
_cancelled_tasks: set = set()  # holds chat_id values; one is added when its task is cancelled

# ── On-disk mirror of the OCR task records ──
_OCR_TASK_DIR = Path("/app") / ".ocr_tasks"
_OCR_TASK_DIR.mkdir(exist_ok=True, parents=True)
def _ocr_task_path(task_id: str) -> Path:
    """Return the JSON file path that backs OCR task *task_id*."""
    return _OCR_TASK_DIR.joinpath(f"{task_id}.json")
def _ocr_task_write(task_id: str, data: dict):
    """Store *data* for an OCR task in memory and mirror it to its JSON file.

    The in-memory entry is always updated. A failure while serialising or
    writing the file is logged and otherwise ignored.
    """
    _ocr_tasks[task_id] = data
    try:
        target = _ocr_task_path(task_id)
        serialized = json.dumps(data, ensure_ascii=False)
        target.write_text(serialized, encoding="utf-8")
    except Exception as e:
        print(f"[ocr_task] 文件写入失败 {task_id}: {e}")
def _ocr_task_read(task_id: str) -> dict | None:
    """Return the record of an OCR task, or None when it is unknown.

    The in-memory registry is consulted first. On a miss the task's JSON file
    is loaded (the restart case) and cached back into memory. A file that
    cannot be read or parsed is logged and treated as missing.
    """
    try:
        return _ocr_tasks[task_id]
    except KeyError:
        pass

    task_file = _ocr_task_path(task_id)
    if not task_file.exists():
        return None
    try:
        restored = json.loads(task_file.read_text(encoding="utf-8"))
        _ocr_tasks[task_id] = restored  # repopulate the in-memory registry
        return restored
    except Exception as e:
        print(f"[ocr_task] 文件读取失败 {task_id}: {e}")
    return None
def _ocr_task_delete(task_id: str):
    # Currently a no-op: nothing is removed from memory or from disk here.
    # Old task files are deleted by _ocr_task_cleanup_old() instead.
    # NOTE(review): presumably disabled on purpose so finished tasks stay
    # queryable — confirm before implementing real deletion.
    pass
def _ocr_task_cleanup_old():
    """Delete task JSON files older than one hour from every task directory.

    Each directory is processed independently: an error while scanning one
    directory is logged and the remaining directories are still handled.
    """
    task_dirs = (_OCR_TASK_DIR, _TRANSLATE_TASK_DIR, _ANALYZE_TASK_DIR, _TASK_TASK_DIR)
    for task_dir in task_dirs:
        try:
            expiry = time.time() - 3600
            for task_file in task_dir.glob("*.json"):
                if task_file.stat().st_mtime >= expiry:
                    continue
                task_file.unlink(missing_ok=True)
        except Exception as e:
            print(f"[task_cleanup] 清理 {task_dir.name} 失败: {e}")
# ── Translation task persistence ──
_TRANSLATE_TASK_DIR = Path("/app") / ".translate_tasks"
_TRANSLATE_TASK_DIR.mkdir(exist_ok=True, parents=True)
def _translate_task_path(task_id: str) -> Path:
    """Return the JSON file path that backs translation task *task_id*."""
    return _TRANSLATE_TASK_DIR.joinpath(f"{task_id}.json")
def _translate_task_write(task_id: str, data: dict):
    """Store *data* for a translation task in memory and mirror it to disk.

    A failure while serialising or writing the file is logged and ignored;
    the in-memory entry is updated regardless.
    """
    _translate_tasks[task_id] = data
    try:
        target = _translate_task_path(task_id)
        serialized = json.dumps(data, ensure_ascii=False)
        target.write_text(serialized, encoding="utf-8")
    except Exception as e:
        print(f"[trans_task] 文件写入失败 {task_id}: {e}")
def _translate_task_read(task_id: str) -> dict | None:
    """Return the record of a translation task, or None when it is unknown.

    Memory is checked first; on a miss the task's JSON file is loaded and
    cached. An unreadable or unparsable file is logged and treated as missing.
    """
    try:
        return _translate_tasks[task_id]
    except KeyError:
        pass

    task_file = _translate_task_path(task_id)
    if not task_file.exists():
        return None
    try:
        restored = json.loads(task_file.read_text(encoding="utf-8"))
        _translate_tasks[task_id] = restored
        return restored
    except Exception as e:
        print(f"[trans_task] 文件读取失败 {task_id}: {e}")
    return None
# ── Analysis task persistence ──
_ANALYZE_TASK_DIR = Path("/app") / ".analyze_tasks"
_ANALYZE_TASK_DIR.mkdir(exist_ok=True, parents=True)
def _analyze_task_path(task_id: str) -> Path:
    """Return the JSON file path that backs analysis task *task_id*."""
    return _ANALYZE_TASK_DIR.joinpath(f"{task_id}.json")
def _analyze_task_write(task_id: str, data: dict):
    """Store *data* for an analysis task in memory and mirror it to disk.

    A failure while serialising or writing the file is logged and ignored;
    the in-memory entry is updated regardless.
    """
    _analyze_tasks[task_id] = data
    try:
        target = _analyze_task_path(task_id)
        serialized = json.dumps(data, ensure_ascii=False)
        target.write_text(serialized, encoding="utf-8")
    except Exception as e:
        print(f"[analyze_task] 文件写入失败 {task_id}: {e}")
def _analyze_task_read(task_id: str) -> dict | None:
    """Return the record of an analysis task, or None when it is unknown.

    Memory is checked first; on a miss the task's JSON file is loaded and
    cached. An unreadable or unparsable file is logged and treated as missing.
    """
    try:
        return _analyze_tasks[task_id]
    except KeyError:
        pass

    task_file = _analyze_task_path(task_id)
    if not task_file.exists():
        return None
    try:
        restored = json.loads(task_file.read_text(encoding="utf-8"))
        _analyze_tasks[task_id] = restored
        return restored
    except Exception as e:
        print(f"[analyze_task] 文件读取失败 {task_id}: {e}")
    return None
# ── Generic task persistence ──
_TASK_TASK_DIR = Path("/app") / ".task_tasks"
_TASK_TASK_DIR.mkdir(exist_ok=True, parents=True)
def _task_task_path(task_id: str) -> Path:
    """Return the JSON file path that backs generic task *task_id*."""
    return _TASK_TASK_DIR.joinpath(f"{task_id}.json")
def _task_task_write(task_id: str, data: dict):
    """Store *data* for a generic task in memory and mirror it to disk.

    A failure while serialising or writing the file is logged and ignored;
    the in-memory entry is updated regardless.
    """
    _task_tasks[task_id] = data
    try:
        target = _task_task_path(task_id)
        serialized = json.dumps(data, ensure_ascii=False)
        target.write_text(serialized, encoding="utf-8")
    except Exception as e:
        print(f"[generic_task] 文件写入失败 {task_id}: {e}")
def _task_task_read(task_id: str) -> dict | None:
    """Return the record of a generic task, or None when it is unknown.

    Memory is checked first; on a miss the task's JSON file is loaded and
    cached. An unreadable or unparsable file is logged and treated as missing.
    """
    try:
        return _task_tasks[task_id]
    except KeyError:
        pass

    task_file = _task_task_path(task_id)
    if not task_file.exists():
        return None
    try:
        restored = json.loads(task_file.read_text(encoding="utf-8"))
        _task_tasks[task_id] = restored
        return restored
    except Exception as e:
        print(f"[generic_task] 文件读取失败 {task_id}: {e}")
    return None
# Number of pages extracted per Lazy-OCR pass for scanned PDFs. Raised to 25,
# which the author sized to match an output of roughly 5000 Chinese characters.
PAGES_PER_CHUNK = 25
| # ══════════════════════════════════════════════════════════════════ | |
| # ModelScope 每日额度持久化管理器 | |
| # ══════════════════════════════════════════════════════════════════ | |
class ModelScopeQuotaManager:
    """Track daily ModelScope usage and failure counters in a JSON file.

    The persisted record has the shape
    ``{"date": "YYYY-MM-DD", "total": int, "usage": {model: int}, "fails": {model: int}}``.
    Every public method reloads the file, mutates the record and saves it
    while holding ``self.lock``.

    Args:
        file_path: Location of the JSON record. Defaults to
            ``/app/.translate_cache/modelscope_quota.json``.
        daily_total_limit: Maximum number of granted calls per day, all models.
        per_model_limit: Maximum number of granted calls per day and model.
        max_consecutive_fails: Consecutive failures after which a model is
            refused for the rest of the day.
    """

    DEFAULT_FILE = "/app/.translate_cache/modelscope_quota.json"

    def __init__(self, file_path=None, daily_total_limit: int = 1000,
                 per_model_limit: int = 200, max_consecutive_fails: int = 5):
        self.lock = threading.Lock()
        self.file_path = Path(self.DEFAULT_FILE if file_path is None else file_path)
        self.daily_total_limit = daily_total_limit
        self.per_model_limit = per_model_limit
        self.max_consecutive_fails = max_consecutive_fails
        self.file_path.parent.mkdir(parents=True, exist_ok=True)
        self._load()

    def _load(self):
        """Reload ``self.data`` from disk; fall back to ``{}`` on any problem."""
        loaded = {}
        if self.file_path.exists():
            try:
                loaded = json.loads(self.file_path.read_text(encoding="utf-8"))
            except Exception:
                loaded = {}
        # Valid JSON that is not an object (``null``, a list, ...) would break
        # every ``.get`` call made on the record later on.
        self.data = loaded if isinstance(loaded, dict) else {}

    def _save(self):
        """Write ``self.data`` to disk; a failure is logged, not raised."""
        try:
            self.file_path.write_text(json.dumps(self.data, ensure_ascii=False), encoding="utf-8")
        except Exception as e:
            print(f"[quota] 保存配额记录失败: {e}")

    def _dict_field(self, name: str) -> dict:
        """Return ``self.data[name]`` as a dict, creating or repairing it if needed."""
        value = self.data.get(name)
        if not isinstance(value, dict):
            value = {}
            self.data[name] = value
        return value

    def increment(self, model_name: str) -> bool:
        """Reserve one call for *model_name*.

        Returns:
            True when the call was counted, False when the model is blocked by
            its failure count, the daily total limit or the per-model limit.
        """
        with self.lock:
            self._load()
            today = time.strftime("%Y-%m-%d", time.localtime())
            if self.data.get("date") != today:
                # First call of a new day: start from a clean record.
                self.data = {"date": today, "total": 0, "usage": {}, "fails": {}}

            # Too many consecutive failures -> skip the model for today.
            fails = self._dict_field("fails")
            if fails.get(model_name, 0) >= self.max_consecutive_fails:
                return False

            usage = self._dict_field("usage")
            total = self.data.get("total", 0)
            if not isinstance(total, int):
                total = 0
            current_usage = usage.get(model_name, 0)
            if total >= self.daily_total_limit:
                return False
            if current_usage >= self.per_model_limit:
                return False

            usage[model_name] = current_usage + 1
            self.data["total"] = total + 1
            self._save()
            return True

    def record_fail(self, model_name: str):
        """Increase the consecutive-failure counter of *model_name* by one."""
        with self.lock:
            self._load()
            fails = self._dict_field("fails")
            count = fails.get(model_name, 0) + 1
            fails[model_name] = count
            if count == self.max_consecutive_fails:
                print(f"[quota] 模型 {model_name} 连续失败 {self.max_consecutive_fails} 次,今日跳过")
            self._save()

    def record_success(self, model_name: str):
        """Reset the consecutive-failure counter of *model_name* to zero."""
        with self.lock:
            self._load()
            fails = self._dict_field("fails")
            if fails.get(model_name, 0) > 0:
                print(f"[quota] 模型 {model_name} 恢复响应,重置失败计数")
            fails[model_name] = 0
            self._save()
# Module-level quota tracker instance; constructing it creates the cache
# directory and loads any existing quota record from disk.
_quota_manager = ModelScopeQuotaManager()
async def startup():
    """Mark the service ready and launch the background housekeeping threads.

    Two daemon threads are started: one deletes stale task files, the other
    makes sure the model files exist and then warms up the embedding model.

    NOTE(review): no startup-hook registration is visible in this view of
    the file — confirm how this coroutine is invoked.
    """
    global _ready
    _ready = True
    print("[startup] OpenWolf Spaces ready")

    # Remove task files left over from before the previous restart.
    threading.Thread(target=_ocr_task_cleanup_old, daemon=True).start()

    def _ensure_models():
        """Download missing model files, then load bge-m3 once as a warm-up."""
        global _model_loaded, _model_loading

        models_root = Path("/app/models")

        # Step 1: embedding model (fetched only when its config file is absent).
        try:
            bge_config = models_root / "bge-m3" / "config.json"
            if not bge_config.exists():
                print("[models] Downloading bge-m3 (2.2GB)...")
                started = time.time()
                from sentence_transformers import SentenceTransformer
                _ = SentenceTransformer("BAAI/bge-m3", device="cpu")
                print(f"[models] bge-m3 done in {time.time()-started:.1f}s")
        except Exception as e:
            print(f"[models] bge-m3 download failed: {e}")

        # Step 2: GGUF translation model. If none of the known files is
        # present, the first candidate in the list is downloaded.
        try:
            translate_dir = models_root / "translate"
            candidates = [
                ("HY-MT1.5-1.8B-Q4_K_M.gguf", 1.13),
                ("HY-MT1.5-1.8B-Q8_0.gguf", 1.91),
            ]
            have_gguf = any((translate_dir / fname).exists() for fname, _size in candidates)
            if not have_gguf:
                fname, size_gb = candidates[0]
                print(f"[models] Downloading {fname} ({size_gb}GB)...")
                started = time.time()
                from huggingface_hub import hf_hub_download
                hf_hub_download(
                    repo_id="tencent/HY-MT1.5-1.8B-GGUF",
                    filename=fname,
                    local_dir=str(translate_dir),
                )
                print(f"[models] GGUF done in {time.time()-started:.1f}s")
        except Exception as e:
            print(f"[models] GGUF download failed: {e}")

        print("[models] All models ready")

        # Step 3: warm-up load; the state flags are updated in both outcomes.
        try:
            print("[warmup] Loading bge-m3...")
            started = time.time()
            from sentence_transformers import SentenceTransformer
            _ = SentenceTransformer("BAAI/bge-m3", device="cpu")
            _model_loaded = True
            _model_loading = False
            print(f"[warmup] bge-m3 loaded in {time.time()-started:.1f}s")
        except Exception as e:
            print(f"[warmup] bge-m3 FAILED: {e}")
            _model_loading = False

    threading.Thread(target=_ensure_models, daemon=True).start()
async def health():
    """Report readiness and, per known variable, whether it is set and non-empty.

    Only a check mark or a cross is returned for each variable; the values
    themselves are never included in the response.
    """
    tracked_vars = (
        "MODELSCOPE_API_KEY", "OPENROUTER_API_KEY", "GOOGLE_API_KEY", "CHATANYWHERE_API_KEY",
        "GROQ_API_KEY", "GITHUB_PAT", "GITHUB_REPO", "TELEGRAM_BOT_TOKEN", "TELEGRAM_CHAT_ID",
        "OPENWOLF_PAT", "SILICONFLOW_API_KEY", "ZHIPU_API_KEY", "NVIDIA_API_KEY",
    )
    env_status = {}
    for var_name in tracked_vars:
        env_status[var_name] = "✅" if os.environ.get(var_name) else "❌"
    return {"status": "ok", "ready": _ready, "env": env_status}
def _r2_extract_text_from_file(local_path: Path, ext: str) -> str:
    """Return the plain text of *local_path*, dispatching on its extension.

    Unsupported extensions yield an empty string.
    """
    if ext == "pdf":
        import pdfplumber
        parts = []
        with pdfplumber.open(local_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if page_text:
                    parts.append(page_text + "\n")
        return "".join(parts)
    if ext in ("txt", "md", "csv", "json"):
        return local_path.read_text(encoding="utf-8", errors="ignore")
    if ext == "docx":
        import docx
        document = docx.Document(local_path)
        return "\n".join(p.text for p in document.paragraphs)
    if ext == "epub":
        import zipfile
        import xml.etree.ElementTree as ET
        parts = []
        with zipfile.ZipFile(local_path, "r") as archive:
            for name in archive.namelist():
                if not name.endswith((".xhtml", ".html", ".htm")):
                    continue
                try:
                    root = ET.fromstring(archive.read(name))
                    for elem in root.iter():
                        if elem.text:
                            parts.append(elem.text.strip() + " ")
                        if elem.tail:
                            parts.append(elem.tail.strip() + " ")
                    parts.append("\n\n")
                except Exception as e:
                    # Best effort: one malformed chapter must not fail the book.
                    print(f"[extract_r2] skipped EPUB entry {name}: {e}")
        return "".join(parts)
    return ""


def _r2_download_and_extract(url: str, file_name: str, ext: str) -> dict:
    """Download *url* to a temporary file, extract its text and build the reply.

    Runs in a worker thread. The temporary file is removed on every path,
    including failures.
    """
    # The extension comes from user input; only a plain alphanumeric one may
    # become part of the file name, so it cannot introduce path separators.
    safe_ext = ext if ext.isalnum() else "bin"
    local_path = Path("/app") / f"inputs/{uuid.uuid4().hex}.{safe_ext}"
    try:
        with requests.get(url, timeout=120, stream=True) as resp:
            if resp.status_code != 200:
                return {"ok": False, "error": f"下载失败 HTTP {resp.status_code}"}
            local_path.parent.mkdir(parents=True, exist_ok=True)
            with open(local_path, "wb") as fh:
                for chunk in resp.iter_content(chunk_size=65536):
                    fh.write(chunk)
        text = _r2_extract_text_from_file(local_path, ext)
        if not text.strip():
            return {"ok": False, "error": "无法提取文本内容"}
        return {"ok": True, "text": text[:50000], "file_name": file_name}
    except Exception as e:
        return {"ok": False, "error": f"提取失败: {e}"}
    finally:
        try:
            local_path.unlink(missing_ok=True)
        except Exception as e:
            print(f"[extract_r2] temp file cleanup failed: {e}")


async def extract_r2(request: Request):
    """Download a file from a public R2 URL and return its extracted text.

    Used by the SPA document-chat upload. Expected JSON body keys: ``url``
    (required) and ``file_name`` (default ``document.pdf``); the file type is
    taken from the extension of ``file_name``. The returned text is capped at
    50000 characters.

    The download and parsing are blocking, so they run in a worker thread
    instead of on the event loop.

    Raises:
        HTTPException: 400 when the body is not a JSON object.
    """
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON")

    url = body.get("url", "")
    file_name = body.get("file_name", "document.pdf")
    if not isinstance(file_name, str):
        file_name = "document.pdf"
    if not url:
        return {"ok": False, "error": "url required"}

    ext = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "pdf"
    return await asyncio.to_thread(_r2_download_and_extract, url, file_name, ext)
| # ══════════════════════════════════════════════════════════════════ | |
| # 异步 OCR 提取端点(PDF 转换面板用,不含 LLM,纯物理提取) | |
| # ══════════════════════════════════════════════════════════════════ | |
async def ocr_pdf(request: Request):
    """Queue a PDF text-extraction job and return its task id immediately.

    Expected JSON body keys: ``url`` (required), ``file_name``,
    ``callback_url`` and ``model``. The job runs on ``_ocr_pdf_pool``.
    Progress and the final result are stored with ``_ocr_task_write`` and,
    when ``callback_url`` is given, also POSTed to that URL.

    Inside the worker the pdfplumber text layer is tried first (skipped when
    ``model == "llamaparse"``); if it yields nothing usable, LlamaParse or
    page-by-page online OCR is used. The result must pass
    ``_check_ocr_quality`` before the task is marked done.

    Raises:
        HTTPException: 400 when the body is not a JSON object.
    """
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON")

    url = body.get("url")
    file_name = body.get("file_name", "document.pdf")
    callback_url = body.get("callback_url")
    ocr_model = body.get("model", "")
    if not url:
        return {"ok": False, "error": "R2 URL is required"}

    task_id = str(uuid.uuid4())
    _ocr_task_write(task_id, {"status": "processing", "progress": 0, "result": None,
                              "file_name": file_name, "callback_url": callback_url or ""})

    def _async_ocr_worker(t_id, pdf_url, fn, cb_url, ocr_model=""):
        extraction_method = ""  # the extraction path actually taken

        # Only the base name of the user-supplied file name reaches the file
        # system, so a crafted name cannot point outside the inputs directory.
        safe_name = os.path.basename(str(fn)) or "document.pdf"
        local_path = Path("/app") / f"inputs/{t_id}_{safe_name}"

        def notify_failure(message):
            """POST a failure notice to the callback; never raises."""
            if not cb_url:
                return
            try:
                requests.post(cb_url, json={"ok": False, "task_id": t_id, "error": message}, timeout=30)
            except Exception as e:
                print(f"[ocr] 失败回调推送异常: {e}")

        def push(pct, status="processing", text=None):
            """Persist the progress and forward it to the callback, if any."""
            current = _ocr_task_read(t_id) or {}
            current["progress"] = pct
            current["method"] = extraction_method  # every update carries the method
            if status != "processing":
                current["status"] = status
            if text is not None:
                current["result"] = text
            _ocr_task_write(t_id, current)
            if not cb_url:
                return
            payload = {"ok": True, "task_id": t_id, "status": status, "progress": pct,
                       "method": extraction_method}
            if text is not None:
                payload["text"] = text
                payload["file_name"] = fn
            for attempt in range(3):
                try:
                    r = requests.post(cb_url, json=payload, timeout=15)
                    if r.status_code < 500:
                        return
                except Exception as e:
                    print(f"[ocr] 推送失败 attempt={attempt+1}: {e}")
                if attempt < 2:  # no point in sleeping after the final attempt
                    time.sleep(2 ** attempt)

        try:
            with requests.get(pdf_url, timeout=120, stream=True) as resp:
                if resp.status_code != 200:
                    _ocr_task_write(t_id, {"status": "error", "progress": 0,
                                           "error": f"Download failed: HTTP {resp.status_code}"})
                    notify_failure(f"HTTP {resp.status_code}")
                    return
                local_path.parent.mkdir(parents=True, exist_ok=True)
                with open(local_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=65536):
                        f.write(chunk)

            import fitz
            doc = fitz.open(local_path)
            try:
                total_pages = len(doc)
            finally:
                doc.close()

            full_text = ""
            _ocr_task_write(t_id, {"status": "processing", "progress": 1,
                                   "file_name": fn, "callback_url": cb_url or "", "method": ""})
            push(1)

            # Try the embedded text layer first; it is skipped for llamaparse.
            if ocr_model != "llamaparse":
                try:
                    import pdfplumber as _pp
                    with _pp.open(local_path) as _p:
                        _text_parts = []
                        for _page in _p.pages:
                            _t = _page.extract_text()
                            if _t:
                                _text_parts.append(_t)
                    if _text_parts:
                        _joined = "\n\n".join(_text_parts)
                        if len(_joined.strip()) > 200:
                            full_text = _joined
                            extraction_method = "pdfplumber"
                            print(f"[ocr] pdfplumber 提取成功,跳过 OCR ({len(full_text)} 字符)")
                except Exception as _e:
                    print(f"[ocr] pdfplumber 提取失败: {_e}")

            if not full_text:
                if ocr_model == "llamaparse":
                    extraction_method = "llamaparse"
                    push(5)
                    full_text = _ocr_pdf_via_llamaparse(local_path)
                    if not full_text:
                        raise Exception("LlamaParse 未返回任何文本,请检查 API Key 或文件格式")
                    push(95)
                else:
                    extraction_method = f"ocr:{ocr_model or 'paddle'}"
                    push_interval = max(1, total_pages // 10)
                    for i in range(total_pages):
                        page_text = _ocr_page_via_siliconflow(local_path, i, ocr_model)
                        if page_text:
                            full_text += page_text + "\n\n"
                        pct = max(1, int(((i + 1) / total_pages) * 100))
                        _ocr_task_write(t_id, {"status": "processing", "progress": pct,
                                               "file_name": fn, "callback_url": cb_url or "",
                                               "method": extraction_method})
                        if i % push_interval == 0 or i == total_pages - 1:
                            push(pct)

            _ok, _reason = _check_ocr_quality(full_text)
            if not _ok:
                # The handler below records the error and notifies the
                # callback exactly once.
                raise Exception(f"OCR质量检测失败:{_reason}")

            _ocr_task_write(t_id, {"status": "done", "progress": 100,
                                   "file_name": fn, "callback_url": cb_url or "",
                                   "method": extraction_method})
            push(100, status="done", text=full_text)
        except Exception as e:
            _ocr_task_write(t_id, {"status": "error", "progress": 0, "error": str(e),
                                   "file_name": fn, "callback_url": cb_url or ""})
            notify_failure(str(e))
        finally:
            # Remove the downloaded PDF on success and on every failure path.
            try:
                local_path.unlink(missing_ok=True)
            except Exception as e:
                print(f"[ocr] 临时文件清理失败: {e}")

    _ocr_pdf_pool.submit(_async_ocr_worker, task_id, url, file_name, callback_url, ocr_model)
    return {"ok": True, "task_id": task_id}
async def ocr_pdf_check(task_id: str):
    """Return status, progress, error and extraction method of an OCR task.

    The lookup goes through ``_ocr_task_read`` (memory first, then the task
    file), so a task can still be queried after a restart.
    """
    record = _ocr_task_read(task_id)
    if record:
        return {
            "ok": True,
            "status": record.get("status", "processing"),
            "progress": record.get("progress", 0),
            "error": record.get("error"),
            "method": record.get("method", ""),
        }
    return {"ok": False, "error": "Task not found"}
| # ══════════════════════════════════════════════════════════════════ | |
| # 辅助洗涤函数:彻底清洗并过滤大模型返回的思维链(<think>...</think>) | |
| # ══════════════════════════════════════════════════════════════════ | |
def clean_think_tags(text: str) -> str:
    """Strip ``<think>...</think>`` reasoning blocks from model output.

    Closed blocks are removed first. If an opening ``<think>`` remains without
    its closing tag (e.g. truncated output), everything from that tag to the
    end of the text is dropped as well. The result is whitespace-trimmed; a
    falsy input yields ``""``.
    """
    if not text:
        return ""
    # Order matters: closed blocks first, then a dangling opening tag.
    for pattern in (r"<think>.*?</think>", r"<think>.*$"):
        text = re.sub(pattern, "", text, flags=re.DOTALL)
    return text.strip()
| # ══════════════════════════════════════════════════════════════════ | |
| # 工具函数群:动态密钥加载与语系检测 | |
| # ══════════════════════════════════════════════════════════════════ | |
def get_multi_api_keys(prefix: str) -> list:
    """Collect the API keys configured for *prefix* from the environment.

    For ``i`` in 1..9 the variable ``{prefix}_{i}`` is preferred over
    ``{prefix}{i}``; the bare ``{prefix}`` variable is appended last. Values
    are whitespace-trimmed, blank values are ignored and duplicates are
    dropped while the first-seen order is kept.

    Returns:
        The list of distinct, non-empty keys (possibly empty).
    """
    keys = []

    def _add(value: str) -> bool:
        """Append *value* if it is non-blank and new; report whether it was non-blank."""
        value = value.strip()
        if not value:
            return False
        if value not in keys:
            keys.append(value)
        return True

    for i in range(1, 10):
        for name in (f"{prefix}_{i}", f"{prefix}{i}"):
            # The first non-blank spelling wins for each index.
            if _add(os.environ.get(name, "")):
                break
    _add(os.environ.get(prefix, ""))
    return keys
def detect_japanese_korean(text: str) -> str:
    """Classify *text* as ``"ja"``, ``"ko"`` or ``"en"`` by script.

    Any hiragana/katakana character means Japanese (checked first), any
    Hangul syllable means Korean; everything else is reported as ``"en"``.
    """
    script_checks = (
        ("ja", r"[\u3040-\u309f\u30a0-\u30ff]"),
        ("ko", r"[\uac00-\ud7af]"),
    )
    for lang_code, pattern in script_checks:
        if re.search(pattern, text):
            return lang_code
    return "en"
| # 将 target_size 调整为 8000 字符(对应 ~1500 英文单词,译出 ~2500 汉字) | |
def semantic_split(text: str, target_size: int = 8000) -> list:
    """Split *text* into chunks of roughly *target_size* characters.

    The text is cut at line breaks only; a single paragraph longer than
    *target_size* is kept whole (there is no hard cut). Blank lines are
    skipped, and so are short lines (< 80 characters) that contain a
    page-number / header / footer marker. Paragraphs inside a chunk are joined
    with a blank line. When nothing survives the filtering, the original text
    is returned as the only chunk.
    """
    noise_markers = ("page", "vol.", "no.", "issn", "doi:", "http://", "https://")
    pieces = []
    pending = []
    pending_len = 0

    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # Drop page numbers and header/footer noise.
        if len(line) < 80:
            lowered = line.lower()
            if any(marker in lowered for marker in noise_markers):
                continue

        line_len = len(line)
        if pending and pending_len + line_len > target_size:
            # The pending chunk is full: emit it and start a new one.
            pieces.append("\n\n".join(pending))
            pending = [line]
            pending_len = line_len
            continue
        pending.append(line)
        pending_len += line_len + 2

    if pending:
        pieces.append("\n\n".join(pending))
    return pieces or [text]
| # ══════════════════════════════════════════════════════════════════ | |
| # OCR 输出质量检测 | |
| # ══════════════════════════════════════════════════════════════════ | |
| def _check_ocr_quality(text: str) -> tuple: | |
| """ | |
| 检测 OCR 输出是否有效,返回 (is_ok, reason) | |
| """ | |
| if not text or len(text.strip()) < 50: | |
| return False, "输出为空" | |
| brace_ratio = text.count('}') / len(text) | |
| if brace_ratio > 0.3: | |
| return False, f"疑似CID字体乱码(}}占比{brace_ratio:.0%})" | |
| lines = text.split('\n') | |
| hash_lines = sum(1 for l in lines if l.strip().startswith('#####')) | |
| if len(lines) > 10 and hash_lines / len(lines) > 0.3: | |
| return False, f"OCR识别失败({hash_lines}/{len(lines)}行为无效内容)" | |
| cjk_count = sum(1 for c in text if '一' <= c <= '鿿' or '' <= c <= 'ヿ') | |
| if len(text) > 200 and cjk_count / len(text) < 0.01: | |
| return False, "未检测到有效中日文字符" | |
| return True, "ok" | |
| # ══════════════════════════════════════════════════════════════════ | |
| # 在线 OCR 模块(硅基流动视觉大模型 Lazy 加载版) | |
| # ══════════════════════════════════════════════════════════════════ | |
def _ocr_page_via_siliconflow(pdf_path: Path, page_index: int, ocr_model: str = "") -> str:
    """OCR one PDF page through the SiliconFlow chat-completions endpoint.

    The page is rendered to PNG at 300 dpi and sent as a base64 data URL
    together with a text-extraction prompt.

    Args:
        pdf_path: PDF file to read.
        page_index: Zero-based page number; an index past the last page
            yields ``""``.
        ocr_model: Model name; when empty, ``OCRAI_OCR_MODEL`` from the
            environment or ``PaddlePaddle/PaddleOCR-VL-1.5`` is used.

    Returns:
        The recognised text, or ``""`` when the API key is missing, the page
        does not exist, the request fails or the response is not HTTP 200.
    """
    import base64

    sf_key = os.environ.get("SILICONFLOW_API_KEY")
    if not sf_key:
        print("[ocr] 未配置 SILICONFLOW_API_KEY")
        return ""
    try:
        import fitz
        doc = fitz.open(pdf_path)
        try:
            if page_index >= len(doc):
                return ""
            pix = doc[page_index].get_pixmap(dpi=300)
            img_bytes = pix.tobytes("png")
        finally:
            # This function is called once per page; without closing, every
            # call would leave an open document handle behind.
            doc.close()
        base64_img = base64.b64encode(img_bytes).decode("utf-8")

        if not ocr_model:
            ocr_model = os.environ.get("OCRAI_OCR_MODEL", "PaddlePaddle/PaddleOCR-VL-1.5")
        url = "https://api.siliconflow.cn/v1/chat/completions"
        headers = {"Authorization": f"Bearer {sf_key.strip()}", "Content-Type": "application/json"}
        body = {
            "model": ocr_model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "完整提取图片中的所有中文文字,保留原文格式和段落。Extract all Chinese text from this image, preserve original paragraphs."},
                        # The payload is PNG data, so it is labelled as PNG.
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_img}"}}
                    ]
                }
            ],
            "temperature": 0.1
        }
        r = requests.post(url, headers=headers, json=body, timeout=40)
        if r.status_code == 200:
            extracted_text = r.json()["choices"][0]["message"]["content"].strip()
            print(f"[ocr] 页面 {page_index+1} OCR 提取成功 ({len(extracted_text)} 字符)")
            return extracted_text
        print(f"[ocr] OCR 异常 HTTP {r.status_code}")
    except Exception as e:
        print(f"[ocr] 页面 {page_index+1} 在线 OCR 失败: {e}")
    return ""
| # ══════════════════════════════════════════════════════════════════ | |
| # LlamaParse AI OCR 整本解析 | |
| # ══════════════════════════════════════════════════════════════════ | |
def _ocr_pdf_via_llamaparse(pdf_path: Path, api_key_override: str = "") -> str:
    """Parse a whole PDF through the LlamaCloud REST API and return its text.

    Steps: upload the file, submit an ``agentic`` parse job, then poll its
    status every 10 seconds for up to 90 rounds. On success the per-page
    markdown is joined with blank lines; if that is empty, the plain-text
    result is tried. Plain ``requests`` calls are used, not the llama-cloud SDK.

    Args:
        pdf_path: PDF file to upload.
        api_key_override: API key to use instead of the
            ``LLAMA_CLOUD_API_KEY`` environment variable.

    Returns:
        The extracted text, or ``""`` on any failure (missing key, HTTP
        error, failed job, empty result, timeout, unexpected exception).
    """
    api_key = api_key_override or os.environ.get("LLAMA_CLOUD_API_KEY")
    if not api_key:
        print("[ocr/llamaparse] 未配置 LLAMA_CLOUD_API_KEY")
        return ""
    headers = {"Authorization": f"Bearer {api_key}"}
    base = "https://api.cloud.llamaindex.ai/api/v2"
    pdf_name = pdf_path.name
    try:
        # 1. Upload the file.
        print(f"[ocr/llamaparse] 上传文件: {pdf_name}")
        with open(pdf_path, "rb") as f:
            upload_resp = requests.post(
                f"{base}/files",
                headers=headers,
                files={"file": (pdf_name, f, "application/pdf")},
                data={"purpose": "parse"},
                timeout=120,
            )
        if upload_resp.status_code != 200:
            print(f"[ocr/llamaparse] 上传失败 HTTP {upload_resp.status_code}: {upload_resp.text[:200]}")
            return ""
        file_id = upload_resp.json()["id"]
        print(f"[ocr/llamaparse] 文件 ID: {file_id}")

        # 2. Submit the parse job.
        parse_resp = requests.post(
            f"{base}/parsing/parse",
            headers={**headers, "Content-Type": "application/json"},
            json={
                "file_id": file_id,
                "tier": "agentic",
                "version": "latest",
                "expand": ["markdown"],
            },
            timeout=30,
        )
        if parse_resp.status_code != 200:
            print(f"[ocr/llamaparse] 解析提交失败 HTTP {parse_resp.status_code}: {parse_resp.text[:200]}")
            return ""
        job_id = parse_resp.json()["id"]
        print(f"[ocr/llamaparse] Job ID: {job_id},轮询结果...")

        # 3. Poll until the job finishes.
        for attempt in range(90):
            time.sleep(10)
            try:
                status_resp = requests.get(f"{base}/parsing/{job_id}/status", headers=headers, timeout=30)
                if status_resp.status_code != 200:
                    continue
                status_data = status_resp.json()
            except (requests.RequestException, ValueError) as poll_err:
                # One dropped connection or garbled reply must not abandon a
                # job that may still complete; just try again next round.
                print(f"[ocr/llamaparse] 轮询 {attempt+1}/90 异常,继续重试: {poll_err}")
                continue
            st = status_data.get("status", "")
            print(f"[ocr/llamaparse] 轮询 {attempt+1}/90: {st}")
            if st in ("SUCCESS", "completed"):
                # Fetch the markdown result.
                result_resp = requests.get(f"{base}/parsing/{job_id}/result", headers=headers, timeout=30)
                if result_resp.ok:
                    result_data = result_resp.json()
                    # ``or`` guards keep a null field from breaking the join
                    # and skipping the plain-text fallback below.
                    pages = (result_data.get("markdown") or {}).get("pages") or []
                    full_md = "\n\n".join(
                        (p.get("markdown") or p.get("md") or "") for p in pages if p
                    )
                    if full_md:
                        print(f"[ocr/llamaparse] 解析完成,{len(full_md)} 字符")
                        return full_md
                # Fallback: fetch the plain-text format.
                text_resp = requests.get(f"{base}/parsing/{job_id}/result?format=text", headers=headers, timeout=30)
                if text_resp.ok:
                    t = text_resp.text
                    if t and len(t.strip()) > 50:
                        print(f"[ocr/llamaparse] 解析完成(text),{len(t)} 字符")
                        return t
                print("[ocr/llamaparse] 结果为空")
                return ""
            elif st in ("ERROR", "CANCELLED", "failed"):
                print(f"[ocr/llamaparse] 解析失败: {status_data}")
                return ""
        print("[ocr/llamaparse] 轮询超时(15分钟)")
        return ""
    except Exception as e:
        print(f"[ocr/llamaparse] 异常: {e}")
        import traceback
        traceback.print_exc()
        return ""
| # ══════════════════════════════════════════════════════════════════ | |
| # 极速引擎:云端 API 5层混合翻译链(带日韩优化与多账号负载) | |
| # ══════════════════════════════════════════════════════════════════ | |
| def _translate_via_cloud_router(text: str, prev_source: str = "", prev_trans: str = "") -> str: | |
| lang = detect_japanese_korean(text) | |
| is_jk = lang in ("ja", "ko") | |
| context_prompt = "" | |
| if prev_source and prev_trans: | |
| context_prompt = ( | |
| f"### 上文翻译参考:\n" | |
| f"【原文】:{prev_source[-200:]}\n" | |
| f"【译文】:{prev_trans[-200:]}\n\n" | |
| ) | |
| system_prompt = ( | |
| "你是一位精通多国语言的资深学术翻译专家。请将下面的英文学术文本翻译成中文。\n" | |
| "## 翻译规则:\n" | |
| "1. 保持专业学术语言风格,用词准确,翻译自然流畅。不要直译,要符合现代中文的阅读习惯。\n" | |
| "2. 专有名词首次出现时请保留英文原文,格式如:“卷积神经网络 (Convolutional Neural Network, CNN)”。\n" | |
| "3. 人名、地名首次出现时使用中英对照。\n" | |
| "4. 严格保持原文段落和标点符号结构的完整,保留代码、公式、数字、年份。不要合并或拆分原文段落。\n" | |
| "5. 仅输出翻译结果,禁止输出任何多余的解释、导语或提示性字样。" | |
| ) | |
| user_content = f"{context_prompt}## 待翻译文本:\n{text}" | |
| # ────── 第一层:ModelScope 官方接口(多密钥轮询) ────── | |
| modelscope_keys = get_multi_api_keys("MODELSCOPE_API_KEY") | |
| if modelscope_keys: | |
| random.shuffle(modelscope_keys) | |
| ms_models = [ | |
| "Qwen/Qwen3-Coder-480B-A35B-Instruct", | |
| "MiniMax/MiniMax-M1-80k", | |
| "deepseek-ai/DeepSeek-V3.2", | |
| "MiniMax/MiniMax-M2.5", | |
| "deepseek-ai/DeepSeek-R1-0528", | |
| "Qwen/Qwen3-235B-A22B-Thinking-2507", | |
| "ZhipuAI/GLM-5", | |
| "Qwen/Qwen3.5-122B-A10B" | |
| ] | |
| if is_jk: | |
| ms_models = [m for m in ms_models if "qwen" in m.lower()] + [m for m in ms_models if "qwen" not in m.lower()] | |
| for model in ms_models: | |
| if not _quota_manager.increment(model): | |
| continue | |
| for key in modelscope_keys: | |
| try: | |
| url = "https://api-inference.modelscope.cn/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2, | |
| "max_tokens": 500 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=35) | |
| if r.status_code == 200: | |
| try: | |
| _choices = r.json().get("choices", []) | |
| if _choices and _choices[0].get("message", {}).get("content"): | |
| _quota_manager.record_success(model) | |
| print(f"[translate] 第一层 ModelScope 翻译成功: {model}") | |
| _text = _choices[0]["message"]["content"].strip() | |
| # 过滤 MiniMax <think> 思维链标签 | |
| _text = re.sub(r'<think>.*?</think>', '', _text, flags=re.DOTALL).strip() | |
| return _text | |
| else: | |
| print(f"[translate] ModelScope {model} HTTP 200 但内容为空") | |
| except Exception as pe: | |
| print(f"[translate] ModelScope {model} 响应解析异常: {pe}") | |
| else: | |
| print(f"[translate] ModelScope {model} HTTP {r.status_code}") | |
| except Exception as e: | |
| print(f"[translate] ModelScope {model} 发生异常: {e}") | |
| _quota_manager.record_fail(model) | |
| # ────── 第二层:主力层(Cerebras & Groq 多密钥均衡) ────── | |
| cerebras_keys = get_multi_api_keys("CEREBRAS_API_KEY") | |
| groq_keys = get_multi_api_keys("GROQ_API_KEY") | |
| random.shuffle(cerebras_keys) | |
| random.shuffle(groq_keys) | |
| # 2.1 Cerebras API | |
| for key in cerebras_keys: | |
| for model in ["gpt-oss-120b", "zai-glm-4.7"]: | |
| try: | |
| url = "https://api.cerebras.ai/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第二层 Cerebras 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] Cerebras 异常: {e}") | |
| # 2.2 Groq API | |
| for key in groq_keys: | |
| for model in ["openai/gpt-oss-120b", "llama-3.3-70b-versatile"]: | |
| try: | |
| url = "https://api.groq.com/openai/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第二层 Groq 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] Groq 异常: {e}") | |
| # 2.3 硅基流动 Hunyuan-MT-7B(翻译专用模型) | |
| sf_key = os.environ.get("SILICONFLOW_API_KEY") | |
| if sf_key: | |
| try: | |
| url = "https://api.siliconflow.cn/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {sf_key.strip()}", "Content-Type": "application/json"} | |
| body = { | |
| "model": "tencent/Hunyuan-MT-7B", | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print("[translate] 第二层 硅基流动 Hunyuan-MT-7B 翻译成功") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] 硅基流动 7B 异常: {e}") | |
| # ────── 第三层:免费补充层(NVIDIA / OpenRouter) ────── | |
| # 3.1 NVIDIA API 通道 | |
| nv_key = os.environ.get("NVIDIA_API_KEY") | |
| if nv_key: | |
| nv_models = [ | |
| "qwen/qwen3.5-397b-a17b", | |
| "qwen/qwen3-coder-480b-a35b-instruct", | |
| "qwen/qwen3.5-122b-a10b", | |
| "z-ai/glm-5.1", | |
| "nvidia/nemotron-3-super-120b-a12b" | |
| ] | |
| if is_jk: | |
| nv_models = [m for m in nv_models if "qwen" in m.lower()] + [m for m in nv_models if "qwen" not in m.lower()] | |
| for model in nv_models: | |
| try: | |
| url = "https://integrate.api.nvidia.com/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {nv_key.strip()}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第三层 NVIDIA 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] NVIDIA 异常: {e}") | |
| # 3.2 OpenRouter | |
| or_key = os.environ.get("OPENWOLF_OR_KEY") or os.environ.get("OPENROUTER_API_KEY") | |
| if or_key: | |
| or_models = [ | |
| "qwen/qwen3-coder:free", | |
| "meta-llama/llama-3.3-70b-instruct:free", | |
| "z-ai/glm-4.5-air:free", | |
| "nvidia/nemotron-3-super-120b-a12b:free", | |
| "qwen/qwen3-next-80b-a3b-instruct:free" | |
| ] | |
| if is_jk: | |
| or_models = [m for m in or_models if "qwen" in m.lower()] + [m for m in or_models if "qwen" not in m.lower()] | |
| for model in or_models: | |
| try: | |
| url = "https://openrouter.ai/api/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {or_key.strip()}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第三层 OpenRouter 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] OpenRouter 异常: {e}") | |
| # ────── 第四层:主力辅助层(Mistral & Opencode 多账号) ────── | |
| mistral_keys = get_multi_api_keys("MISTRAL_API_KEY") | |
| opencode_keys = get_multi_api_keys("OPENCODE_API_KEY") | |
| random.shuffle(mistral_keys) | |
| random.shuffle(opencode_keys) | |
| # 4.1 Mistral | |
| for key in mistral_keys: | |
| for model in ["mistral-large-latest", "mistral-medium-latest"]: | |
| try: | |
| url = "https://api.mistral.ai/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第四层 Mistral 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] Mistral 异常: {e}") | |
| # 4.2 Opencode | |
| for key in opencode_keys: | |
| for model in ["big-pickle", "nemotron-3-super-free", "deepseek-v4-flash-free"]: | |
| try: | |
| url = "https://opencode.ai/zen/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第四层 opencode 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] opencode 异常: {e}") | |
| # ────── 第五层:轻量兜底层(完全免费,顺序串联) ────── | |
| # 5.1 智谱免费 Flash | |
| zp_key = os.environ.get("ZHIPU_API_KEY") | |
| if zp_key: | |
| for model in ["glm-4.7-flash", "glm-4.6-flash", "GLM-Z1-Flash", "GLM-4-Flash"]: | |
| try: | |
| url = "https://open.bigmodel.cn/api/paas/v4/chat/completions" | |
| headers = {"Authorization": f"Bearer {zp_key.strip()}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第五层 智谱 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] 智谱 异常: {e}") | |
| # 5.2 硅基流动小模型完全免费通道 | |
| if sf_key: | |
| sf_free_models = [ | |
| "deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", | |
| "Qwen/Qwen3.5-4B", | |
| "Qwen/Qwen3-8B", | |
| "THUDM/GLM-Z1-9B-0414", | |
| "THUDM/GLM-4-9B-0414" | |
| ] | |
| if is_jk: | |
| sf_free_models = [m for m in sf_free_models if "qwen" in m.lower()] + [m for m in sf_free_models if "qwen" not in m.lower()] | |
| for model in sf_free_models: | |
| try: | |
| url = "https://api.siliconflow.cn/v1/chat/completions" | |
| headers = {"Authorization": f"Bearer {sf_key.strip()}", "Content-Type": "application/json"} | |
| body = { | |
| "model": model, | |
| "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}], | |
| "temperature": 0.2 | |
| } | |
| r = requests.post(url, headers=headers, json=body, timeout=60) | |
| if r.status_code == 200: | |
| _c = r.json()["choices"][0]["message"]["content"] | |
| if _c and _c.strip(): | |
| print(f"[translate] 第五层 硅基流动 翻译成功: {model}") | |
| return _c.strip() | |
| except Exception as e: | |
| print(f"[translate] 硅基流动免费通道异常: {e}") | |
| return "" | |
| # ══════════════════════════════════════════════════════════════════ | |
| # 本地 GGUF 离线安全兜底推理引擎(轻量直观版 Prompt,保障 1.8B 稳定翻译) | |
| # ══════════════════════════════════════════════════════════════════ | |
def _get_llama():
    """Return the process-wide llama.cpp model, loading it on first use.

    The Q4_K_M quantisation is preferred; the Q8_0 file is used when no Q4
    file is present. A file only counts as ready once it is larger than
    100 MiB, so a download that is still being written is never handed to
    llama.cpp.

    Returns:
        The shared ``llama_cpp.Llama`` instance.

    Raises:
        FileNotFoundError: no complete GGUF file became available in time.
        TimeoutError: the model did not finish loading within 90 seconds.
        Exception: whatever the ``Llama`` constructor raised.
    """
    global _llama_model
    if _llama_model is not None:
        return _llama_model

    with _llama_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting for it.
        if _llama_model is not None:
            return _llama_model

        q4_path = "/app/models/translate/HY-MT1.5-1.8B-Q4_K_M.gguf"
        q8_path = "/app/models/translate/HY-MT1.5-1.8B-Q8_0.gguf"
        min_bytes = 100 * 1024 * 1024

        def _is_ready(path):
            """True when `path` is a regular file above the size threshold."""
            try:
                return os.path.isfile(path) and os.path.getsize(path) > min_bytes
            except OSError:
                # The file vanished between the two calls.
                return False

        model_path = None
        # Wait up to 60 s for the preferred file. If there is no Q4 file at
        # all but Q8 is complete, use Q8 right away instead of idling.
        for waited in range(60):
            if _is_ready(q4_path):
                model_path = q4_path
                break
            if not os.path.exists(q4_path) and _is_ready(q8_path):
                model_path = q8_path
                break
            print(f"[llama] 正在等待本地兜底模型准备就绪... {waited}s")
            time.sleep(1)

        # Short grace period in which either file is acceptable.
        if model_path is None:
            for _ in range(10):
                if _is_ready(q4_path):
                    model_path = q4_path
                    break
                if _is_ready(q8_path):
                    model_path = q8_path
                    break
                time.sleep(1)

        if model_path is None:
            raise FileNotFoundError("GGUF model file not found in translate directory")

        from llama_cpp import Llama

        outcome = {}
        done = threading.Event()
        # Cap the worker threads at 4, falling back to 4 when the core count
        # is unknown.
        num_cores = os.cpu_count() or 4
        optimal_threads = max(1, min(4, num_cores))

        def _load():
            global _llama_model
            try:
                print(f"[llama] Loading HY-MT1.5 with {optimal_threads} threads...")
                t0 = time.time()
                model = Llama(model_path=model_path, n_ctx=8192,
                              n_threads=optimal_threads, n_gpu_layers=0, verbose=False,
                              use_mmap=True, use_mlock=False)
                print(f"[llama] Loaded in {time.time()-t0:.1f}s")
                outcome["model"] = model
                # Publish from the loader thread so that a load finishing
                # after the caller's timeout is still usable by later calls.
                if _llama_model is None:
                    _llama_model = model
            except Exception as e:
                outcome["error"] = e
            finally:
                done.set()

        # Load in a daemon thread so a stuck load cannot block the caller
        # (and the lock) indefinitely.
        threading.Thread(target=_load, daemon=True).start()
        if not done.wait(timeout=90):
            raise TimeoutError("GGUF model loading timed out (90s)")
        if "error" in outcome:
            raise outcome["error"]
        return _llama_model
| def _translate_chunk_local(text: str) -> str: | |
| """ | |
| 轻量直译版 Prompt,对 1.8B 小模型深度优化,100%老实翻译 | |
| """ | |
| llm = _get_llama() | |
| _max = max(64, min(1024, int(len(text) * 1.5))) | |
| _prompt = ( | |
| "将以下学术英文文本完整翻译成中文。要求:\n" | |
| "1. 保持专业学术语言风格。\n" | |
| "2. 专有名词首次出现时保留英文原文。\n" | |
| "3. 人名和地名不翻译,直接保留原文。\n" | |
| "4. 严格保持原有段落和标点结构不变。\n" | |
| "5. 保留公式、数字和年份。\n" | |
| "6. 只输出翻译好的中文译文,禁止输出任何拼音、多余的解释或说明文字。\n\n" | |
| f"{text}" | |
| ) | |
| with _infer_lock: | |
| out = llm.create_chat_completion( | |
| messages=[{"role": "user", "content": _prompt}], | |
| max_tokens=_max, temperature=0.1, | |
| ) | |
| return out["choices"][0]["message"]["content"].strip() | |
async def debug_model():
    """Report which GGUF files exist, their sizes and the llama_cpp version.

    Returns:
        JSONResponse with keys ``q4_exists``, ``q8_exists``, ``q4_size_gb``,
        ``q8_size_gb``, ``llama_loaded`` and ``llama_cpp_version`` (``None``
        when llama_cpp cannot be imported).
    """
    q4_file = "/app/models/translate/HY-MT1.5-1.8B-Q4_K_M.gguf"
    q8_file = "/app/models/translate/HY-MT1.5-1.8B-Q8_0.gguf"
    has_q4 = os.path.isfile(q4_file)
    has_q8 = os.path.isfile(q8_file)

    def _size_gb(path, present):
        # Size in GiB rounded to two decimals; 0 for a missing file.
        if not present:
            return 0
        return round(os.path.getsize(path) / 1024**3, 2)

    report = {
        "q4_exists": has_q4,
        "q8_exists": has_q8,
        "q4_size_gb": _size_gb(q4_file, has_q4),
        "q8_size_gb": _size_gb(q8_file, has_q8),
        "llama_loaded": _llama_model is not None,
    }
    try:
        import llama_cpp
    except ImportError:
        report["llama_cpp_version"] = None
    else:
        report["llama_cpp_version"] = llama_cpp.__version__
    return JSONResponse(report)
| # ══════════════════════════════════════════════════════════════════ | |
| # 接口:防弹版异步翻译任务启动(带懒加载 OCR 提取与 <think> 过滤) | |
| # ══════════════════════════════════════════════════════════════════ | |
async def api_job_cancel(request: Request):
    """Mark every translation job of one chat as cancelled.

    Reads ``chat_id`` from the JSON body and adds it to `_cancelled_tasks`.

    Returns:
        ``{"ok": True}`` on success, otherwise ``{"ok": False, "error": ...}``
        (missing chat_id, unparsable body, or any other failure).
    """
    try:
        payload = await request.json()
        target_chat = str(payload.get("chat_id") or "")
        if not target_chat:
            return {"ok": False, "error": "chat_id required"}
        _cancelled_tasks.add(target_chat)
        print(f"[cancel] 已标记取消: chat_id={target_chat}")
        return {"ok": True}
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
async def api_job_start(request: Request):
    """Queue the translation of one document chunk and return its task id.

    The JSON body is handed unchanged to a worker on `_translate_pool`; the
    fields the worker reads are documented on `_tjob_run`. Progress is
    reported through `_translate_task_write` and, when ``callback_url`` is
    present, an HTTP POST.

    Returns:
        ``{"ok": True, "task_id": ...}`` once the job is queued, otherwise
        ``{"ok": False, "error": ...}``.
    """
    try:
        body = await request.json()
    except Exception as e:
        return {"ok": False, "error": f"JSON 解析失败: {e}"}
    try:
        task_id = str(uuid.uuid4())
        _translate_task_write(task_id, {"status": "processing", "result": None})
        # A new job clears a stale cancel flag for its chat. This happens at
        # submission time (not inside the worker) so that a cancel arriving
        # while the job is still queued is seen by the worker's first check.
        if isinstance(body, dict):
            _cancelled_tasks.discard(str(body.get("chat_id") or "default"))
        _translate_pool.submit(_tjob_run, task_id, body)
        return {"ok": True, "task_id": task_id}
    except Exception as e:
        return {"ok": False, "error": f"路由层报错: {e}"}


def _tjob_safe_component(value, fallback):
    """Turn `value` into a single path component that cannot leave its directory."""
    cleaned = re.sub(r"[\\/\x00]", "_", str(value))
    return fallback if cleaned in ("", ".", "..") else cleaned


def _tjob_local_candidate(file_path):
    """Return an existing local file for `file_path` under /app or the cwd, else None."""
    roots = (Path("/app").resolve(), Path.cwd().resolve())
    for candidate in (Path("/app") / str(file_path), Path(str(file_path))):
        try:
            resolved = candidate.resolve()
        except (OSError, RuntimeError):
            continue
        if not resolved.is_file():
            continue
        # Containment check: the request body is untrusted, so paths such as
        # /proc/self/environ or ../../etc/passwd must not be readable.
        if any(root == resolved or root in resolved.parents for root in roots):
            return resolved
    return None


def _tjob_fetch_source(r2_url, dl_url, file_path, file_id, safe_name, ext):
    """Locate or download the source document, trying four sources in order.

    Order: R2 URL, generic download URL, repository path (local file first,
    then the GitHub contents API), Telegram file id.

    Returns:
        ``(path, is_temp)``; `is_temp` is True when the file was downloaded
        into /app/inputs and may be deleted by the caller. ``(None, False)``
        when every source failed.
    """
    from urllib.parse import urlparse

    inputs_dir = Path("/app") / "inputs"
    gh_pat = (os.environ.get("OPENWOLF_PAT") or os.environ.get("GITHUB_PAT")
              or os.environ.get("GITHUB_TOKEN") or "")

    def _new_target(suffix):
        inputs_dir.mkdir(parents=True, exist_ok=True)
        return inputs_dir / f"{uuid.uuid4().hex}{suffix}"

    def _stream_to(resp, target):
        # Write in 64 KiB blocks; remove the partial file if the stream breaks.
        try:
            with open(target, "wb") as fh:
                for block in resp.iter_content(chunk_size=65536):
                    fh.write(block)
        except BaseException:
            try:
                target.unlink()
            except OSError:
                pass
            raise

    # Source 1: R2 object storage.
    if r2_url:
        try:
            with requests.get(r2_url, timeout=120, stream=True) as resp:
                if resp.status_code == 200:
                    target = _new_target(f"_{safe_name}")
                    _stream_to(resp, target)
                    return target, True
                print(f"[download] R2 HTTP {resp.status_code}: {r2_url}")
        except Exception as exc:
            print(f"[download] R2 异常: {exc}")

    # Source 2: plain download URL.
    if dl_url:
        try:
            headers = {}
            # Attach the token only when the host really is api.github.com; a
            # substring test would leak it to e.g. https://evil.test/?api.github.com
            if gh_pat and urlparse(str(dl_url)).hostname == "api.github.com":
                headers["Authorization"] = f"Bearer {gh_pat}"
                headers["Accept"] = "application/vnd.github.raw"
            resp = requests.get(dl_url, headers=headers, timeout=120)
            if resp.status_code == 200:
                target = _new_target(f".{ext}")
                target.write_bytes(resp.content)
                return target, True
            print(f"[download] download_url HTTP {resp.status_code}")
        except Exception as exc:
            print(f"[download] download_url 异常: {exc}")

    # Source 3: repository path, local copy first, then the GitHub API.
    if file_path:
        local = _tjob_local_candidate(file_path)
        if local is not None:
            return local, False
        gh_repo = os.environ.get("GITHUB_REPO", "hughyonng/OpenWolf")
        if gh_pat:
            try:
                api_url = f"https://api.github.com/repos/{gh_repo}/contents/{file_path}"
                headers = {"Authorization": f"Bearer {gh_pat}", "Accept": "application/vnd.github.raw"}
                resp = requests.get(api_url, headers=headers, timeout=120)
                if resp.status_code == 200:
                    target = _new_target(f".{ext}")
                    target.write_bytes(resp.content)
                    return target, True
                print(f"[download] GitHub API HTTP {resp.status_code}")
            except Exception as exc:
                print(f"[download] GitHub API 异常: {exc}")

    # Source 4: Telegram bot file API.
    if file_id:
        try:
            token = os.environ.get("TELEGRAM_BOT_TOKEN", "")
            if token:
                meta = requests.get(f"https://api.telegram.org/bot{token}/getFile",
                                    params={"file_id": file_id}, timeout=30)
                info = meta.json().get("result", {}) if meta.ok else {}
                remote = info.get("file_path", "")
                if meta.ok and remote:
                    with requests.get(f"https://api.telegram.org/file/bot{token}/{remote}",
                                      timeout=300, stream=True) as dl:
                        if dl.ok:
                            tail = _tjob_safe_component(remote.split("/")[-1], "file")
                            target = _new_target(f"_{tail}")
                            _stream_to(dl, target)
                            return target, True
        except Exception as exc:
            print(f"[download] Telegram 异常: {exc}")

    return None, False


def _tjob_probe_pdf(pdf_path, local_only, known_pages=0):
    """Decide whether a PDF is a scan by sampling the text of its first 3 pages.

    Returns:
        ``(is_scanned, total_pages)``. A probe failure is reported as a
        text PDF with `known_pages` (or the page count read so far).
    """
    total = known_pages
    try:
        import fitz
        doc = fitz.open(pdf_path)
        try:
            total = len(doc)
            sample = "".join(doc[i].get_text() or "" for i in range(min(3, total)))
        finally:
            doc.close()
    except Exception as e:
        print(f"[translate] 预判定 PDF 属性异常,降级到电子版读取: {e}")
        return False, total

    if len(sample.strip()) >= 100:
        print("[translate] 检测到 PDF 为电子文本版,直接提取")
        return False, total
    if local_only:
        # Local-only mode never uses the online OCR service.
        print("[translate] 安全本地模式激活:强制关闭在线 OCR,采用电子文本降级读取")
        return False, total
    print("[translate] 检测到 PDF 为扫描版图片件,自动开启 Lazy-OCR 通道")
    return True, total


def _tjob_extract_text(doc_path, ext):
    """Extract plain text from a pdf/txt/md/csv/json/docx/epub file ("" if unsupported)."""
    if ext == "pdf":
        import pdfplumber
        with pdfplumber.open(doc_path) as pdf:
            return "\n".join(page.extract_text() or "" for page in pdf.pages)
    if ext in ("txt", "md", "csv", "json"):
        with open(doc_path, "r", encoding="utf-8", errors="ignore") as fh:
            return fh.read()
    if ext == "docx":
        import docx
        return "\n".join(p.text for p in docx.Document(str(doc_path)).paragraphs)
    if ext == "epub":
        import zipfile
        import xml.etree.ElementTree as ET
        parts = []
        with zipfile.ZipFile(doc_path, "r") as archive:
            for name in archive.namelist():
                if not name.endswith((".xhtml", ".html", ".htm")):
                    continue
                try:
                    root = ET.fromstring(archive.read(name))
                    for elem in root.iter():
                        if elem.text:
                            parts.append(elem.text.strip() + " ")
                        if elem.tail:
                            parts.append(elem.tail.strip() + " ")
                    parts.append("\n\n")
                except Exception as exc:
                    # Best effort: one malformed member must not abort the book.
                    print(f"[translate] epub 成员解析失败 {name}: {exc}")
        return "".join(parts)
    return ""


def _tjob_translate(ctx_dir, ci, source_text, local_only, local_log, fallback_log):
    """Translate one chunk, persist source and translation, prune older chunk files.

    The previous chunk's source and translation (when both files exist) are
    passed to the cloud router as context. Only the current and previous
    ``src_N.txt`` / ``trans_N.txt`` files are kept.
    """
    prev_source = ""
    prev_trans = ""
    prev_src_file = ctx_dir / f"src_{ci - 1}.txt"
    prev_trs_file = ctx_dir / f"trans_{ci - 1}.txt"
    if ci > 0 and prev_src_file.is_file() and prev_trs_file.is_file():
        prev_source = prev_src_file.read_text(encoding="utf-8", errors="ignore")
        prev_trans = prev_trs_file.read_text(encoding="utf-8", errors="ignore")

    if local_only:
        # Local-only mode bypasses every online model.
        if local_log:
            print(local_log)
        translated = _translate_chunk_local(source_text)
    else:
        translated = _translate_via_cloud_router(source_text, prev_source, prev_trans)
        if not translated:
            print(fallback_log)
            translated = _translate_chunk_local(source_text)

    # Strip <think>...</think> reasoning blocks before saving or returning.
    translated = clean_think_tags(translated)

    (ctx_dir / f"src_{ci}.txt").write_text(source_text, encoding="utf-8")
    (ctx_dir / f"trans_{ci}.txt").write_text(translated, encoding="utf-8")

    for cached in ctx_dir.glob("*.txt"):
        if not cached.name.startswith(("src_", "trans_")):
            continue
        try:
            idx = int(cached.name.split("_")[1].split(".")[0])
            if idx < ci - 1:
                cached.unlink()
        except (ValueError, IndexError, OSError):
            # Unparsable name or a file that is already gone: leave it.
            continue
    return translated


def _tjob_run(t_id, payload):
    """Worker: fetch/parse the document if needed, translate one chunk, report.

    Fields read from `payload`: ``fileId``/``file_id``, ``file_path``/
    ``filePath``, ``download_url``/``downloadUrl``, ``r2_download_url``/
    ``r2DownloadUrl``, ``fileName``/``filename``, ``is_local_only``,
    ``chunk_index``, ``chat_id``, ``callback_url``, ``bus_id``.

    Every failure is caught, stored as an ``error`` task state and, when a
    callback URL is present, pushed to it.
    NOTE(review): download and callback URLs come from the request body and
    are fetched as-is; confirm this endpoint is not exposed to untrusted callers.
    """
    try:
        _file_id = payload.get("fileId") or payload.get("file_id")
        _file_path = payload.get("file_path") or payload.get("filePath")
        _dl_url = payload.get("download_url") or payload.get("downloadUrl")
        _r2_url = payload.get("r2_download_url") or payload.get("r2DownloadUrl")
        _orig_fn = payload.get("fileName") or payload.get("filename") or "document.pdf"
        # The reported name stays as sent; only filesystem use is sanitised.
        _safe_fn = _tjob_safe_component(_orig_fn, "document.pdf")
        _orig_ext = _safe_fn.rsplit(".", 1)[-1].lower() if "." in _safe_fn else "pdf"
        _orig_ext = re.sub(r"[^a-z0-9]", "", _orig_ext) or "pdf"
        _is_local_only = payload.get("is_local_only", False)

        # chunk_index may arrive as a string, None or -1; normalise to int >= 0.
        try:
            _ci = int(payload.get("chunk_index", 0))
        except (TypeError, ValueError):
            _ci = 0
        if _ci < 0:
            _ci = 0

        _chat_id_str = str(payload.get("chat_id") or "default")
        _context_dir = Path("/app/.context_cache") / _tjob_safe_component(_chat_id_str, "default")
        _context_dir.mkdir(parents=True, exist_ok=True)
        _chunks_cache_file = _context_dir / "chunks_list.json"
        _meta_cache_file = _context_dir / "pdf_metadata.json"
        _fixed_path = _context_dir / f"source_document.{_orig_ext}"

        _chunks = []
        _is_scanned = False
        _total_pages = 0

        if _meta_cache_file.is_file():
            try:
                _meta = json.loads(_meta_cache_file.read_text(encoding="utf-8"))
                _is_scanned = _meta.get("is_scanned", False)
                _total_pages = _meta.get("total_pages", 0)
            except Exception:
                # Unreadable metadata is treated as absent.
                pass

        # Honour a cancel request that arrived while the job was queued.
        if _chat_id_str in _cancelled_tasks:
            print(f"[translate] 检测到取消标志,中止翻译: chat_id={_chat_id_str}")
            _translate_task_write(t_id, {"status": "error", "result": "任务已取消"})
            _cancelled_tasks.discard(_chat_id_str)
            return

        # Continuation of a text document: reuse the cached chunk list.
        if _ci > 0 and not _is_scanned and _chunks_cache_file.is_file():
            try:
                _chunks = json.loads(_chunks_cache_file.read_text(encoding="utf-8"))
                print(f"[translate] 命中切片缓存。当前分段: {_ci + 1}/{len(_chunks)}")
            except Exception as _ce:
                print(f"[translate] 载入分段缓存失败: {_ce}")

        # Continuation of a scanned PDF: the cached copy and page count are
        # enough, so the document is not downloaded again for every chunk.
        _scan_resume = bool(_ci > 0 and _is_scanned and _total_pages > 0 and _fixed_path.is_file())

        # First chunk, or the cache is gone: fetch and parse the document.
        if not _chunks and not _scan_resume:
            print(f"[download] 开始定位/下载文件: file_id={_file_id}, file_path={_file_path}")
            _src, _src_is_temp = _tjob_fetch_source(
                _r2_url, _dl_url, _file_path, _file_id, _safe_fn, _orig_ext)
            if not _src or not _src.is_file():
                raise ValueError("无法在所有防护层中下载或定位待翻译文档")

            import shutil
            shutil.copy2(_src, _fixed_path)
            if _src_is_temp:
                # The working copy lives in the context dir; drop the download.
                try:
                    _src.unlink()
                except OSError:
                    pass

            if _orig_ext == "pdf":
                _is_scanned, _total_pages = _tjob_probe_pdf(_fixed_path, _is_local_only, _total_pages)
            else:
                _is_scanned = False

            try:
                _meta_cache_file.write_text(json.dumps({
                    "is_scanned": _is_scanned,
                    "total_pages": _total_pages,
                    "file_ext": _orig_ext,
                    "file_name": _orig_fn,
                }, ensure_ascii=False), encoding="utf-8")
            except Exception as e:
                print(f"[meta_cache] 写入缓存异常: {e}")

            if not _is_scanned:
                _full_text = _tjob_extract_text(_fixed_path, _orig_ext)
                if not _full_text.strip():
                    raise ValueError("文本提取为空,请检查文件是否加密")
                # Join soft-wrapped lines: a single newline that does not follow
                # sentence-ending punctuation becomes a space.
                _full_text = re.sub(r'(?<![。.!?!?」』)\)"])\n(?!\n)', ' ', _full_text)
                # Semantic split with a target chunk size of 8000 characters.
                _chunks = semantic_split(_full_text, target_size=8000)
                try:
                    _chunks_cache_file.write_text(json.dumps(_chunks, ensure_ascii=False), encoding="utf-8")
                    print(f"[translate] 电子书分段写入成功,总段数: {len(_chunks)}")
                except Exception as _se:
                    print(f"[translate] 写入切片缓存失败: {_se}")

        # ---- chunk dispatch ----
        if _is_scanned:
            # Case 1: scanned PDF, one chunk = PAGES_PER_CHUNK pages.
            _total_chunks = (_total_pages + PAGES_PER_CHUNK - 1) // PAGES_PER_CHUNK
            if _ci >= _total_chunks:
                result_payload = {
                    "translated_text": "🎉 本书已通过在线 OCR 全部翻译完毕!",
                    "has_more": False,
                    "chunk_index": _ci,
                    "total_chunks": _total_chunks,
                    "file_path": _file_path
                }
            else:
                start_page = _ci * PAGES_PER_CHUNK
                end_page = min(start_page + PAGES_PER_CHUNK, _total_pages)
                print(f"[ocr] 正在提取 scanned PDF 第 {start_page+1} 至 {end_page} 页...")
                _page_texts = []
                for p_idx in range(start_page, end_page):
                    page_text = _ocr_page_via_siliconflow(_fixed_path, p_idx)
                    if page_text:
                        _page_texts.append(page_text + "\n\n")
                _chunk_raw_text = "".join(_page_texts)
                if not _chunk_raw_text.strip():
                    raise ValueError(f"在线 OCR 未能在第 {start_page+1}~{end_page} 页识别到任何有效字符")
                # Same pre-translation cancel check as the text branch.
                if _chat_id_str in _cancelled_tasks:
                    raise ValueError("任务已被用户取消")
                _tr = _tjob_translate(
                    _context_dir, _ci, _chunk_raw_text, _is_local_only,
                    None,
                    "[translate] 在线路由空转,降级本地 GGUF 直译...")
                result_payload = {
                    "translated_text": _tr,
                    "has_more": (_ci + 1) < _total_chunks,
                    "chunk_index": _ci,
                    "total_chunks": _total_chunks,
                    "file_path": _file_path
                }
        else:
            # Case 2: text document, chunks come from the semantic split.
            if not _chunks and _chunks_cache_file.is_file():
                try:
                    _chunks = json.loads(_chunks_cache_file.read_text(encoding="utf-8"))
                except Exception:
                    pass
            _total_chunks = len(_chunks) if _chunks else 1
            if _ci >= _total_chunks or not _chunks:
                result_payload = {
                    "translated_text": "🎉 本书已翻译完毕!",
                    "has_more": False,
                    "chunk_index": _ci,
                    "total_chunks": _total_chunks,
                    "file_path": _file_path
                }
            else:
                _chunk_to_trans = _chunks[_ci]
                # Check the cancel flag right before the expensive call.
                if _chat_id_str in _cancelled_tasks:
                    raise ValueError("任务已被用户取消")
                _tr = _tjob_translate(
                    _context_dir, _ci, _chunk_to_trans, _is_local_only,
                    "[translate] 安全本地模式激活:强行绕过所有在线 AI 接口,仅执行本地 GGUF 翻译",
                    "[translate] 在线接口空转,降级本地 GGUF 兜底翻译...")
                result_payload = {
                    "translated_text": _tr,
                    "has_more": (_ci + 1) < _total_chunks,
                    "chunk_index": _ci,
                    "total_chunks": _total_chunks,
                    "file_path": _file_path
                }

        _result_json = json.dumps(result_payload, ensure_ascii=False)
        _translate_task_write(t_id, {"status": "done", "result": _result_json})

        # Push the result to the caller's callback URL, if one was given.
        _cb_url = payload.get("callback_url") or ""
        if _cb_url:
            try:
                requests.post(_cb_url, json={
                    "ok": True,
                    "task_id": t_id,
                    "status": "done",
                    "result": _result_json,
                    "bus_id": payload.get("bus_id") or "",
                    "chat_id": str(payload.get("chat_id") or ""),
                    "chunk_index": result_payload.get("chunk_index", 0),
                    "has_more": result_payload.get("has_more", False),
                    "total_chunks": result_payload.get("total_chunks", 1),
                    "file_path": result_payload.get("file_path") or _file_path or "",
                    "file_name": _orig_fn,
                }, timeout=15)
                print(f"[api_job_start] 回调推送成功: {_cb_url}")
            except Exception as _ce:
                print(f"[api_job_start] 回调推送失败: {_ce}")
    except Exception as e:
        import traceback
        print(f"[api_job_start] 翻译子线程异常: {e}")
        traceback.print_exc()
        _translate_task_write(t_id, {"status": "error", "result": f"模型推理报错: {e}"})
        # The payload itself may be the problem (e.g. not a JSON object).
        _is_mapping = isinstance(payload, dict)
        _cb_url = (payload.get("callback_url") or "") if _is_mapping else ""
        if _cb_url:
            try:
                requests.post(_cb_url, json={"ok": False, "task_id": t_id, "error": str(e),
                                             "bus_id": payload.get("bus_id") or "",
                                             "chat_id": str(payload.get("chat_id") or "")}, timeout=15)
            except Exception:
                # Best effort: the error is already stored in the task state.
                pass
async def api_job_check(task_id: str):
    """Return the state of a translation task.

    Returns:
        ``status`` is ``processing``, ``done`` or ``error``; ``result`` is
        included for finished tasks. Unknown ids and lookup failures are
        reported with ``ok`` False and ``status`` ``error``.
    """
    try:
        record = _translate_task_read(task_id)
        if not record:
            return {"ok": False, "status": "error", "result": "任务ID不存在"}
        state = record["status"]
        if state not in ("done", "error"):
            return {"ok": True, "status": "processing"}
        return {"ok": True, "status": state, "result": record["result"]}
    except Exception as exc:
        return {"ok": False, "status": "error", "result": f"检查报错: {exc}"}
| # ── 以下原有文档分析及其他路由逻辑保持完整 ── | |
async def analyze_doc(request: Request):
    """Download a PDF from ``url`` and check that text can be extracted from it.

    Body fields: ``url`` (required) and ``max_chars`` (must be convertible to
    int). The extracted text is not returned by this endpoint.

    Returns:
        ``{"ok": True, "result": "分析完成"}`` when text was extracted,
        otherwise ``{"ok": False, "error": ...}``.

    Raises:
        HTTPException: 400 when the body is not valid JSON.
    """
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    url = body.get("url", "")
    # Kept as input validation: a non-numeric max_chars is rejected up front.
    int(body.get("max_chars", 50000))
    if not url:
        return {"ok": False, "error": "url required"}

    def _download_and_extract():
        """Blocking part: returns (http_status, extracted_text)."""
        with requests.get(url, timeout=300, stream=True) as resp:
            if resp.status_code != 200:
                return resp.status_code, ""
            local_path = Path("/app") / f"inputs/{uuid.uuid4().hex}.pdf"
            local_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                with open(local_path, "wb") as fh:
                    for block in resp.iter_content(chunk_size=65536):
                        fh.write(block)
                import pdfplumber
                pieces = []
                with pdfplumber.open(local_path) as pdf:
                    for page in pdf.pages:
                        page_text = page.extract_text()
                        if page_text:
                            pieces.append(page_text + "\n")
                return 200, "".join(pieces)
            finally:
                # Remove the temp file even when parsing fails.
                try:
                    local_path.unlink()
                except OSError:
                    pass

    # Run the download and PDF parsing off the event loop so other requests
    # are not stalled for up to the 300 s download timeout.
    status, text = await asyncio.get_running_loop().run_in_executor(None, _download_and_extract)
    if status != 200:
        return {"ok": False, "error": f"下载失败 HTTP {status}"}
    if not text.strip():
        return {"ok": False, "error": "无法提取文本内容"}
    return {"ok": True, "result": "分析完成"}
async def analyze_doc_start(request: Request):
    """Queue a background analysis of the PDF at ``url`` and return a task id.

    Body fields: ``url`` (required), ``question`` (has a default) and
    ``max_chars`` (default 50000, must be convertible to int).

    Raises:
        HTTPException: 400 when the body is not valid JSON.
    """
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")

    doc_url = payload.get("url", "")
    user_question = payload.get("question", "请分析这份文档的内容")
    char_limit = int(payload.get("max_chars", 50000))
    if not doc_url:
        return {"ok": False, "error": "url required"}

    new_task_id = str(uuid.uuid4())
    _analyze_task_write(new_task_id, {"status": "processing", "result": None})
    _analyze_pool.submit(_do_analyze_async, new_task_id, doc_url, user_question, char_limit)
    return {"ok": True, "task_id": new_task_id}
async def analyze_text_start(request: Request):
    """Queue a background analysis of already-extracted text.

    Body fields: ``text`` and ``question``, both required.

    Raises:
        HTTPException: 400 when the body is not valid JSON.
    """
    try:
        payload = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")

    source_text = payload.get("text", "")
    user_question = payload.get("question", "")
    if not (source_text and user_question):
        return {"ok": False, "error": "text and question required"}

    new_task_id = str(uuid.uuid4())
    _analyze_task_write(new_task_id, {"status": "processing", "result": None})
    _analyze_pool.submit(_do_analyze_text_async, new_task_id, source_text, user_question)
    return {"ok": True, "task_id": new_task_id}
async def analyze_text_check(task_id: str):
    """Return the state of a text-analysis task (delegates to `analyze_doc_check`)."""
    state = await analyze_doc_check(task_id)
    return state
async def analyze_doc_check(task_id: str):
    """Return the state of an analysis task.

    Finished tasks carry ``result`` and, when the stored record has a
    non-empty ``doc_text``, that text as well.
    """
    record = _analyze_task_read(task_id)
    if not record:
        return {"ok": False, "status": "error", "result": "任务ID不存在"}
    if record["status"] not in ("done", "error"):
        return {"ok": True, "status": "processing"}

    reply = {"ok": True, "status": record["status"], "result": record["result"]}
    extracted = record.get("doc_text")
    if extracted:
        reply["doc_text"] = extracted
    return reply
def _do_analyze_async(task_id: str, url: str, question: str, max_chars: int):
    """Worker: download the PDF at `url`, extract its text and hand it on.

    The first `max_chars` characters of the extracted text are passed to
    `_do_analyze_text_async`. Every failure is recorded as an ``error``
    state through `_analyze_task_write`; nothing is raised to the pool.
    """
    try:
        # The context manager releases the streamed connection on every path.
        with requests.get(url, timeout=300, stream=True) as resp:
            if resp.status_code != 200:
                _analyze_task_write(task_id, {"status": "error", "result": f"下载失败 HTTP {resp.status_code}"})
                return
            local_path = Path("/app") / f"inputs/{uuid.uuid4().hex}.pdf"
            local_path.parent.mkdir(parents=True, exist_ok=True)
            try:
                with open(local_path, "wb") as fh:
                    for block in resp.iter_content(chunk_size=65536):
                        fh.write(block)
                import pdfplumber
                pieces = []
                with pdfplumber.open(local_path) as pdf:
                    for page in pdf.pages:
                        page_text = page.extract_text()
                        if page_text:
                            pieces.append(page_text + "\n")
                text = "".join(pieces)
            finally:
                # Remove the temp file even when the download or parse fails.
                try:
                    local_path.unlink()
                except OSError:
                    pass
        if not text.strip():
            _analyze_task_write(task_id, {"status": "error", "result": "无法提取文本内容"})
            return
        _do_analyze_text_async(task_id, text[:max_chars], question)
    except Exception as e:
        _analyze_task_write(task_id, {"status": "error", "result": f"分析失败: {e}"})
def _do_analyze_text_async(task_id: str, doc_text: str, question: str):
    """Worker: mark an analysis task as done and store its document text.

    `question` is accepted for interface compatibility but not used here.
    A failure while storing the result is recorded as an ``error`` state.
    """
    try:
        _analyze_task_write(task_id, {"status": "done", "result": "分析完成", "doc_text": doc_text})
    except Exception as e:
        _analyze_task_write(task_id, {"status": "error", "result": f"分析失败: {e}"})
async def task_start(request: Request):
    """Queue an agent task and return its id without waiting for the result.

    The request body must be a JSON object. Recognised keys:

    * ``task``      -- task text; required and must be non-empty
    * ``chat_id``   -- converted to ``str`` before being passed on
    * ``task_type`` -- optional; falsy values are normalised to ``None``
    * ``history``   -- optional; defaults to an empty list

    Returns:
        ``{"ok": True, "task_id": ...}`` once the job is submitted to
        ``_task_pool``, or ``{"ok": False, "error": "task required"}`` when
        no task text was supplied.

    Raises:
        HTTPException: 400 when the body is not valid JSON or is not a JSON
            object.
    """
    try:
        body = await request.json()
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON")
    # Valid JSON can still be a list/string/number; `.get` below needs a dict.
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON object required")

    task_text = body.get("task", "")
    chat_id = body.get("chat_id", "")
    task_type = body.get("task_type") or None
    history = body.get("history", [])
    if not task_text:
        return {"ok": False, "error": "task required"}

    task_id = str(uuid.uuid4())
    # Record the task as in-flight before submitting it to the pool.
    _task_task_write(task_id, {"status": "processing", "result": None})
    _task_pool.submit(_do_task_async, task_id, task_text, str(chat_id), task_type, history)
    return {"ok": True, "task_id": task_id}
async def task_check(task_id: str):
    """Report the current state of the agent task identified by *task_id*.

    Returns a dict with ``ok`` and ``status``; a ``result`` field is included
    for unknown ids and for tasks that have reached "done" or "error".
    """
    record = _task_task_read(task_id)
    if not record:
        return {"ok": False, "status": "error", "result": "任务ID不存在"}

    finished = record["status"] in ("done", "error")
    if not finished:
        return {"ok": True, "status": "processing"}

    # Read both fields from one copy of the record.
    snapshot = record.copy()
    return {"ok": True, "status": snapshot["status"], "result": snapshot["result"]}
def _do_task_async(task_id: str, task_text: str, chat_id: str, task_type: str = None, history: list = None):
    """Run one agent task and store its outcome under *task_id*.

    Calls ``run_agent_task`` from ``scripts.ai_agent`` and writes a "done"
    record holding the stringified result. Any exception, including a failed
    import of the agent module, is turned into an "error" record.

    Args:
        task_id: Key under which the outcome is written.
        task_text: Task description passed to the agent.
        chat_id: Chat identifier passed to the agent.
        task_type: Optional task type forwarded as a keyword argument.
        history: Optional history; ``None`` is replaced by an empty list.
    """
    turns = [] if history is None else history
    try:
        from scripts.ai_agent import run_agent_task

        outcome = run_agent_task(task_text, turns, None, chat_id, "consumer", task_type=task_type)
        done_record = {"status": "done", "result": str(outcome)}
        _task_task_write(task_id, done_record)
    except Exception as exc:
        _task_task_write(task_id, {"status": "error", "result": f"处理失败: {exc}"})
async def skill_search(request: Request):
    """Search the skill index by case-insensitive substring.

    The ``q`` query parameter is trimmed and lower-cased, then matched
    against each skill's name and description. An empty query yields an
    empty JSON list. Each hit carries ``id``, ``name`` and the first 200
    characters of the description.
    """
    needle = request.query_params.get("q", "").strip().lower()
    if not needle:
        return JSONResponse([])

    def _matches(skill) -> bool:
        # Name is checked first; description only if the name does not match.
        return needle in skill.get("name", "").lower() or needle in skill.get("description", "").lower()

    hits = [
        {"id": skill["id"], "name": skill["name"], "description": skill.get("description", "")[:200]}
        for skill in _get_skill_index().get("skills", [])
        if _matches(skill)
    ]
    return JSONResponse(hits)
async def skill_view(request: Request):
    """Return the README of the first matching skill whose README can be fetched.

    The ``name`` query parameter is trimmed and lower-cased, then matched as a
    substring against each skill's id and name. For every match the README is
    requested from the repository's raw-content URL; the first successful
    response is returned with its content truncated to 2000 characters.

    Returns:
        * 400 ``{"error": "name required"}`` when ``name`` is empty
        * 200 ``{"name": ..., "content": ...}`` on the first successful fetch
        * 404 ``{"error": "not found"}`` when no match yields a README
    """
    name = request.query_params.get("name", "").strip().lower()
    if not name:
        return JSONResponse({"error": "name required"}, status_code=400)

    loop = asyncio.get_running_loop()
    for s in _get_skill_index().get("skills", []):
        sid = s.get("id", "").lower()
        if name not in sid and name not in s.get("name", "").lower():
            continue
        # NOTE(review): the URL uses the lower-cased id, as in the original;
        # confirm library directory names are always lower-case.
        readme_url = f"https://raw.githubusercontent.com/hughyonng/OpenWolf/refs/heads/main/skills/library/{sid}/README.md"
        try:
            # requests is blocking; run it in the default executor so the
            # event loop keeps serving other requests during the download.
            r = await loop.run_in_executor(None, lambda: requests.get(readme_url, timeout=10))
            if r.ok:
                return JSONResponse({"name": s["name"], "content": r.text[:2000]})
        except Exception as exc:
            # Best-effort: log the failure and try the next matching skill.
            print(f"[skill_view] README fetch failed for {sid}: {exc}")
    return JSONResponse({"error": "not found"}, status_code=404)