ComfyUI-Ranking-API / router_proxy.py
ZHIWEI666's picture
Upload 3 files
25f34c9 verified
# router_proxy.py
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, Response
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, field_validator, HttpUrl
from typing import Optional
import httpx
import os
import re
import urllib.request
import urllib.error
import 数据库连接 as json_db
from database_sql import get_db
from models_sql import Ownership
router = APIRouter()
class ProxyGithubZipRequest(BaseModel):
"""GitHub ZIP 代理下载请求模型"""
url: str = Field(..., min_length=1, description="资源URL")
item_id: str = Field(..., min_length=1, max_length=100, description="资源ID")
account: str = Field(..., min_length=1, max_length=50, description="用户账号")
@field_validator('item_id')
@classmethod
def validate_item_id(cls, v: str) -> str:
"""验证item_id只允许[a-zA-Z0-9_-]字符"""
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError('item_id只能包含字母、数字、下划线和连字符')
return v
@field_validator('account')
@classmethod
def validate_account(cls, v: str) -> str:
"""验证account非空且长度符合要求"""
if len(v) < 1:
raise ValueError('account不能为空')
if len(v) > 50:
raise ValueError('account长度不能超过50个字符')
return v
@router.post("/api/proxy_github_zip")
async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)):
"""云端代理:校验所有权后,拉取 GitHub 私有库 ZIP 流透传给本地"""
items_db = json_db.load_data("items.json", default_data=[])
item = next((i for i in items_db if i["id"] == req_data.item_id), None)
if not item: return JSONResponse(content={"error": "资源不存在"}, status_code=404)
# 1. 核心鉴权:验证用户是否购买过该工具
try:
price = int(float(item.get("price", 0) or 0))
except (ValueError, TypeError):
price = 0
if price > 0 and req_data.account != item.get("author"):
owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
if not owned:
return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403)
# ===== ZIP 缓存优先机制(鉴权后、GitHub请求前) =====
try:
versions_db = json_db.load_data("versions.json", default_data={})
item_version = versions_db.get(req_data.item_id)
# 兼容新格式:{"item_id": {"hash": "...", "cached_at": "...", "zip_size": ...}}
if isinstance(item_version, dict) and item_version.get("cached_at"):
cache_filename = f"zips/{req_data.item_id}_{item_version['hash']}.zip"
cache_url = f"https://huggingface.co/datasets/ZHIWEI666/ComfyUI-Ranking/resolve/main/{cache_filename}"
cache_client = httpx.AsyncClient(follow_redirects=True)
head_resp = None
try:
head_resp = await cache_client.head(cache_url, timeout=10, follow_redirects=True)
except Exception as head_err:
print(f"[ZIP缓存] HEAD验证失败: {head_err}")
await cache_client.aclose()
raise
if head_resp.status_code == 200:
print(f"[ZIP缓存] 命中: {req_data.item_id}, 大小: {item_version.get('zip_size', 'unknown')}")
content_length = head_resp.headers.get('content-length', '') or str(item_version.get('zip_size', ''))
try:
stream_resp = await cache_client.send(
cache_client.build_request("GET", cache_url),
stream=True
)
except Exception as stream_err:
print(f"[ZIP缓存] 流请求失败: {stream_err}")
await cache_client.aclose()
raise
cache_headers = {}
if content_length:
cache_headers["Content-Length"] = content_length
async def cache_stream_generator():
try:
if stream_resp.status_code != 200:
yield b"CACHE_DOWNLOAD_FAILED"
return
async for chunk in stream_resp.aiter_bytes():
yield chunk
except Exception:
yield b"CACHE_DOWNLOAD_FAILED"
finally:
await stream_resp.aclose()
await cache_client.aclose()
return StreamingResponse(cache_stream_generator(), media_type="application/zip", headers=cache_headers)
# 缓存未命中:关闭 client 后继续原有逻辑
await cache_client.aclose()
except Exception as e:
print(f"[ZIP缓存] 查询失败,fallback到GitHub直连: {e}")
# 2. 解析 GitHub 仓库信息
repo_url = item.get("link", "").rstrip("/")
if not repo_url.startswith("https://github.com/"):
return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400)
repo_parts = repo_url.split("/")
if len(repo_parts) < 2: return JSONResponse(content={"error": "无效的仓库地址格式"}, status_code=400)
owner = repo_parts[-2]
repo = repo_parts[-1].replace(".git", "")
# GitHub 官方提供的打包下载 API
github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball"
# 优先读取该资源在数据库中绑定的专属创作者 Token
creator_token = item.get("github_token")
# 如果没填,尝试使用官方全局兜底的 PAT
fallback_token = os.environ.get("GITHUB_PAT")
active_token = creator_token if creator_token else fallback_token
headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "ComfyUI-Ranking-SaaS"
}
if active_token:
headers["Authorization"] = f"Bearer {active_token}"
# 3. 异步请求 GitHub API 并以流形式透传回客户端 (防内存打爆)
client = httpx.AsyncClient(follow_redirects=True)
try:
response = await client.send(
client.build_request("GET", github_zip_api, headers=headers),
stream=True
)
# 获取 Content-Length 并透传给客户端,使本地端能计算下载进度
content_length = response.headers.get('content-length', '')
resp_headers = {}
if content_length:
resp_headers["Content-Length"] = content_length
async def stream_generator():
try:
if response.status_code != 200:
yield b"GITHUB_DOWNLOAD_FAILED"
return
async for chunk in response.aiter_bytes():
yield chunk
except Exception:
yield b"GITHUB_DOWNLOAD_FAILED"
finally:
await response.aclose()
await client.aclose()
return StreamingResponse(stream_generator(), media_type="application/zip", headers=resp_headers)
except Exception as e:
await client.aclose()
return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500)
# ==========================================
# 新增:工作流/应用 (App) JSON 代理下载接口
# ==========================================
class ProxyDownloadRequest(BaseModel):
"""工作流/应用 JSON 代理下载请求模型"""
url: str = Field(..., min_length=1, description="下载URL")
item_id: str = Field(..., min_length=1, max_length=100, description="资源ID")
account: str = Field(..., min_length=1, max_length=50, description="用户账号")
@field_validator('item_id')
@classmethod
def validate_item_id(cls, v: str) -> str:
"""验证item_id只允许[a-zA-Z0-9_-]字符"""
if not re.match(r'^[a-zA-Z0-9_-]+$', v):
raise ValueError('item_id只能包含字母、数字、下划线和连字符')
return v
@field_validator('account')
@classmethod
def validate_account(cls, v: str) -> str:
"""验证account非空且长度符合要求"""
if len(v) < 1:
raise ValueError('account不能为空')
if len(v) > 50:
raise ValueError('account长度不能超过50个字符')
return v
@field_validator('url')
@classmethod
def validate_url(cls, v: str) -> str:
"""验证URL格式"""
if not v.startswith(('http://', 'https://')):
raise ValueError('url必须是有效的HTTP或HTTPS地址')
return v
@router.post("/api/proxy_download")
async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)):
"""云端代理:校验所有权后,拉取真实的工作流 JSON 文件并透传给本地"""
items_db = json_db.load_data("items.json", default_data=[])
item = next((i for i in items_db if i["id"] == req_data.item_id), None)
if not item:
return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404)
# 1. 核心鉴权:验证用户是否购买/获取过该工作流
price = int(item.get("price", 0))
if price > 0 and req_data.account != item.get("author"):
owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
if not owned:
return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403)
target_url = req_data.url
# 🚀 核心修复:从环境变量提取 Hugging Face Token,并组装 Authorization 请求头
hf_token = os.environ.get("HF_TOKEN")
headers = {"User-Agent": "ComfyUI-Ranking-SaaS"}
if hf_token and "huggingface.co" in target_url:
headers["Authorization"] = f"Bearer {hf_token}"
# 🚀 核心修复:使用异步 httpx 替代同步 urllib,避免阻塞事件循环
# ⚠️ 安全警告:仅在网络环境导致证书验证失败时,通过环境变量临时关闭
verify_ssl = os.environ.get("DISABLE_SSL_VERIFY", "").lower() not in ("1", "true")
if not verify_ssl:
print("⚠️ SSL证书验证已关闭,请仅在调试环境使用")
try:
async with httpx.AsyncClient(follow_redirects=True, verify=verify_ssl, timeout=120.0) as client:
print(f"🔍 开始下载资源 [{req_data.item_id}]")
print(f"🔗 目标地址:{target_url[:80]}...")
response = await client.get(target_url, headers=headers)
if response.status_code != 200:
print(f"❌ 源文件拉取失败 [HTTP {response.status_code}]")
return JSONResponse(
content={"error": f"源文件拉取失败,HTTP 状态码:{response.status_code}"},
status_code=response.status_code
)
content = response.content
print(f"✅ 成功下载资源 [{req_data.item_id}], 大小:{len(content)} bytes")
return Response(
content=content,
media_type="application/json",
status_code=200,
headers={"Content-Disposition": f"attachment; filename={req_data.item_id}.json"}
)
except httpx.TimeoutException as e:
print(f"❌ 下载超时:{str(e)}")
return JSONResponse(content={"error": "下载超时,请稍后重试"}, status_code=504)
except Exception as e:
import traceback
print(f"❌ 代理下载异常:{str(e)}")
print(traceback.format_exc())
return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500)