# router_proxy.py from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse, JSONResponse, Response from sqlalchemy.orm import Session from pydantic import BaseModel, Field, field_validator, HttpUrl from typing import Optional import httpx import os import re import urllib.request import urllib.error import 数据库连接 as json_db from database_sql import get_db from models_sql import Ownership router = APIRouter() class ProxyGithubZipRequest(BaseModel): """GitHub ZIP 代理下载请求模型""" url: str = Field(..., min_length=1, description="资源URL") item_id: str = Field(..., min_length=1, max_length=100, description="资源ID") account: str = Field(..., min_length=1, max_length=50, description="用户账号") @field_validator('item_id') @classmethod def validate_item_id(cls, v: str) -> str: """验证item_id只允许[a-zA-Z0-9_-]字符""" if not re.match(r'^[a-zA-Z0-9_-]+$', v): raise ValueError('item_id只能包含字母、数字、下划线和连字符') return v @field_validator('account') @classmethod def validate_account(cls, v: str) -> str: """验证account非空且长度符合要求""" if len(v) < 1: raise ValueError('account不能为空') if len(v) > 50: raise ValueError('account长度不能超过50个字符') return v @router.post("/api/proxy_github_zip") async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)): """云端代理:校验所有权后,拉取 GitHub 私有库 ZIP 流透传给本地""" items_db = json_db.load_data("items.json", default_data=[]) item = next((i for i in items_db if i["id"] == req_data.item_id), None) if not item: return JSONResponse(content={"error": "资源不存在"}, status_code=404) # 1. 核心鉴权:验证用户是否购买过该工具 try: price = int(float(item.get("price", 0) or 0)) except (ValueError, TypeError): price = 0 if price > 0 and req_data.account != item.get("author"): owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first() if not owned: return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403) # ===== ZIP 缓存优先机制(鉴权后、GitHub请求前) ===== try: versions_db = json_db.load_data("versions.json", default_data={}) item_version = versions_db.get(req_data.item_id) # 兼容新格式:{"item_id": {"hash": "...", "cached_at": "...", "zip_size": ...}} if isinstance(item_version, dict) and item_version.get("cached_at"): cache_filename = f"zips/{req_data.item_id}_{item_version['hash']}.zip" cache_url = f"https://huggingface.co/datasets/ZHIWEI666/ComfyUI-Ranking/resolve/main/{cache_filename}" cache_client = httpx.AsyncClient(follow_redirects=True) head_resp = None try: head_resp = await cache_client.head(cache_url, timeout=10, follow_redirects=True) except Exception as head_err: print(f"[ZIP缓存] HEAD验证失败: {head_err}") await cache_client.aclose() raise if head_resp.status_code == 200: print(f"[ZIP缓存] 命中: {req_data.item_id}, 大小: {item_version.get('zip_size', 'unknown')}") content_length = head_resp.headers.get('content-length', '') or str(item_version.get('zip_size', '')) try: stream_resp = await cache_client.send( cache_client.build_request("GET", cache_url), stream=True ) except Exception as stream_err: print(f"[ZIP缓存] 流请求失败: {stream_err}") await cache_client.aclose() raise cache_headers = {} if content_length: cache_headers["Content-Length"] = content_length async def cache_stream_generator(): try: if stream_resp.status_code != 200: yield b"CACHE_DOWNLOAD_FAILED" return async for chunk in stream_resp.aiter_bytes(): yield chunk except Exception: yield b"CACHE_DOWNLOAD_FAILED" finally: await stream_resp.aclose() await cache_client.aclose() return StreamingResponse(cache_stream_generator(), media_type="application/zip", headers=cache_headers) # 缓存未命中:关闭 client 后继续原有逻辑 await cache_client.aclose() except Exception as e: print(f"[ZIP缓存] 查询失败,fallback到GitHub直连: {e}") # 2. 解析 GitHub 仓库信息 repo_url = item.get("link", "").rstrip("/") if not repo_url.startswith("https://github.com/"): return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400) repo_parts = repo_url.split("/") if len(repo_parts) < 2: return JSONResponse(content={"error": "无效的仓库地址格式"}, status_code=400) owner = repo_parts[-2] repo = repo_parts[-1].replace(".git", "") # GitHub 官方提供的打包下载 API github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball" # 优先读取该资源在数据库中绑定的专属创作者 Token creator_token = item.get("github_token") # 如果没填,尝试使用官方全局兜底的 PAT fallback_token = os.environ.get("GITHUB_PAT") active_token = creator_token if creator_token else fallback_token headers = { "Accept": "application/vnd.github.v3+json", "User-Agent": "ComfyUI-Ranking-SaaS" } if active_token: headers["Authorization"] = f"Bearer {active_token}" # 3. 异步请求 GitHub API 并以流形式透传回客户端 (防内存打爆) client = httpx.AsyncClient(follow_redirects=True) try: response = await client.send( client.build_request("GET", github_zip_api, headers=headers), stream=True ) # 获取 Content-Length 并透传给客户端,使本地端能计算下载进度 content_length = response.headers.get('content-length', '') resp_headers = {} if content_length: resp_headers["Content-Length"] = content_length async def stream_generator(): try: if response.status_code != 200: yield b"GITHUB_DOWNLOAD_FAILED" return async for chunk in response.aiter_bytes(): yield chunk except Exception: yield b"GITHUB_DOWNLOAD_FAILED" finally: await response.aclose() await client.aclose() return StreamingResponse(stream_generator(), media_type="application/zip", headers=resp_headers) except Exception as e: await client.aclose() return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500) # ========================================== # 新增:工作流/应用 (App) JSON 代理下载接口 # ========================================== class ProxyDownloadRequest(BaseModel): """工作流/应用 JSON 代理下载请求模型""" url: str = Field(..., min_length=1, description="下载URL") item_id: str = Field(..., min_length=1, max_length=100, description="资源ID") account: str = Field(..., min_length=1, max_length=50, description="用户账号") @field_validator('item_id') @classmethod def validate_item_id(cls, v: str) -> str: """验证item_id只允许[a-zA-Z0-9_-]字符""" if not re.match(r'^[a-zA-Z0-9_-]+$', v): raise ValueError('item_id只能包含字母、数字、下划线和连字符') return v @field_validator('account') @classmethod def validate_account(cls, v: str) -> str: """验证account非空且长度符合要求""" if len(v) < 1: raise ValueError('account不能为空') if len(v) > 50: raise ValueError('account长度不能超过50个字符') return v @field_validator('url') @classmethod def validate_url(cls, v: str) -> str: """验证URL格式""" if not v.startswith(('http://', 'https://')): raise ValueError('url必须是有效的HTTP或HTTPS地址') return v @router.post("/api/proxy_download") async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)): """云端代理:校验所有权后,拉取真实的工作流 JSON 文件并透传给本地""" items_db = json_db.load_data("items.json", default_data=[]) item = next((i for i in items_db if i["id"] == req_data.item_id), None) if not item: return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404) # 1. 核心鉴权:验证用户是否购买/获取过该工作流 price = int(item.get("price", 0)) if price > 0 and req_data.account != item.get("author"): owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first() if not owned: return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403) target_url = req_data.url # 🚀 核心修复:从环境变量提取 Hugging Face Token,并组装 Authorization 请求头 hf_token = os.environ.get("HF_TOKEN") headers = {"User-Agent": "ComfyUI-Ranking-SaaS"} if hf_token and "huggingface.co" in target_url: headers["Authorization"] = f"Bearer {hf_token}" # 🚀 核心修复:使用异步 httpx 替代同步 urllib,避免阻塞事件循环 # ⚠️ 安全警告:仅在网络环境导致证书验证失败时,通过环境变量临时关闭 verify_ssl = os.environ.get("DISABLE_SSL_VERIFY", "").lower() not in ("1", "true") if not verify_ssl: print("⚠️ SSL证书验证已关闭,请仅在调试环境使用") try: async with httpx.AsyncClient(follow_redirects=True, verify=verify_ssl, timeout=120.0) as client: print(f"🔍 开始下载资源 [{req_data.item_id}]") print(f"🔗 目标地址:{target_url[:80]}...") response = await client.get(target_url, headers=headers) if response.status_code != 200: print(f"❌ 源文件拉取失败 [HTTP {response.status_code}]") return JSONResponse( content={"error": f"源文件拉取失败,HTTP 状态码:{response.status_code}"}, status_code=response.status_code ) content = response.content print(f"✅ 成功下载资源 [{req_data.item_id}], 大小:{len(content)} bytes") return Response( content=content, media_type="application/json", status_code=200, headers={"Content-Disposition": f"attachment; filename={req_data.item_id}.json"} ) except httpx.TimeoutException as e: print(f"❌ 下载超时:{str(e)}") return JSONResponse(content={"error": "下载超时,请稍后重试"}, status_code=504) except Exception as e: import traceback print(f"❌ 代理下载异常:{str(e)}") print(traceback.format_exc()) return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500)