# ComfyUI-Ranking-API / router_proxy.py
# Uploaded by ZHIWEI666 ("Upload 19 files", commit f68778c, verified), 5.54 kB.
# router_proxy.py
import asyncio
import os
import urllib.error
import urllib.parse
import urllib.request

import httpx
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, Response
from pydantic import BaseModel
from sqlalchemy.orm import Session

import 数据库连接 as json_db
from database_sql import get_db
from models_sql import Ownership
# Router holding this module's proxy endpoints; presumably included into the FastAPI
# app elsewhere (not visible here).
router = APIRouter()
class ProxyGithubZipRequest(BaseModel):
    """Request body for ``POST /api/proxy_github_zip``."""

    # Accepted for client compatibility; proxy_github_zip does not read it and uses the
    # stored item's "link" field instead.
    url: str
    # ID of the item in items.json whose GitHub repository should be proxied.
    item_id: str
    # Caller's account name, compared against the item's author and Ownership records.
    account: str
def _parse_github_owner_repo(repo_url):
    """Return ``(owner, repo)`` parsed from a ``https://github.com/...`` URL, or ``None``.

    Only the first two path segments are used, so links such as
    ``https://github.com/owner/repo/tree/main`` still resolve to ``("owner", "repo")``.
    A query string or fragment is ignored and a trailing ``.git`` suffix is removed.
    Returns ``None`` for non-GitHub URLs or when the owner/repo pair is incomplete.
    """
    prefix = "https://github.com/"
    if not repo_url.startswith(prefix):
        return None
    path = repo_url[len(prefix):].split("#", 1)[0].split("?", 1)[0]
    segments = [segment for segment in path.split("/") if segment]
    if len(segments) < 2:
        return None
    owner, repo = segments[0], segments[1]
    # Strip only a *suffix*; a name like "my.github-tool" must stay untouched.
    if repo.endswith(".git"):
        repo = repo[: -len(".git")]
    if not repo:
        return None
    return owner, repo


@router.post("/api/proxy_github_zip")
async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)):
    """Cloud proxy: after checking ownership, stream a GitHub (private) repo ZIP to the local client.

    Error responses are JSON objects with an ``"error"`` key:
      * 404 - ``item_id`` is not present in items.json
      * 403 - the item is paid, the caller is not its author and has no Ownership record
      * 400 - the item's ``link`` is not a usable ``https://github.com/<owner>/<repo>`` URL
      * 502 - the GitHub request failed or returned a non-200 status; the message starts
              with the legacy ``GITHUB_DOWNLOAD_FAILED`` marker
    On success the archive is streamed as ``application/zip``. If the transfer breaks after
    streaming has begun, the status code can no longer change, so the legacy marker bytes
    ``GITHUB_DOWNLOAD_FAILED`` are appended to the body instead.

    NOTE(review): ``req_data.account`` comes from the request body and no authentication is
    visible in this module, so the author/ownership check trusts a client-supplied name.
    """
    items_db = json_db.load_data("items.json", default_data=[])
    item = next((i for i in items_db if i.get("id") == req_data.item_id), None)
    if not item:
        return JSONResponse(content={"error": "资源不存在"}, status_code=404)

    # 1. Authorization: paid items require a purchase record unless the caller is the author.
    #    A missing/empty price counts as free; an unparseable price becomes NaN, and
    #    `not (nan <= 0)` is True, so such items fail closed (treated as paid).
    try:
        price = float(item.get("price") or 0)
    except (TypeError, ValueError):
        price = float("nan")
    is_paid = not (price <= 0)
    if is_paid and req_data.account != item.get("author"):
        owned = (
            db.query(Ownership)
            .filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id)
            .first()
        )
        if not owned:
            return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403)

    # 2. Resolve owner/repo from the item's stored repository link.
    repo_url = (item.get("link") or "").strip().rstrip("/")
    if not repo_url.startswith("https://github.com/"):
        return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400)
    owner_repo = _parse_github_owner_repo(repo_url)
    if owner_repo is None:
        return JSONResponse(content={"error": "无效的仓库地址格式"}, status_code=400)
    owner, repo = owner_repo

    # GitHub's official repository-archive endpoint (it answers with a redirect to the ZIP,
    # hence follow_redirects below).
    github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball"

    # Prefer the creator token bound to this item; otherwise fall back to the global PAT.
    active_token = item.get("github_token") or os.environ.get("GITHUB_PAT")
    headers = {
        "Accept": "application/vnd.github.v3+json",
        "User-Agent": "ComfyUI-Ranking-SaaS",
    }
    if active_token:
        headers["Authorization"] = f"Bearer {active_token}"

    # 3. Open the upstream request *before* answering, so a GitHub failure is reported with
    #    a real error status instead of a 200 "application/zip" body that is not a ZIP.
    client = httpx.AsyncClient(follow_redirects=True, timeout=120.0)
    try:
        upstream = await client.send(
            client.build_request("GET", github_zip_api, headers=headers), stream=True
        )
    except Exception:
        # Best-effort boundary: any transport/URL error becomes a gateway error.
        await client.aclose()
        return JSONResponse(content={"error": "GITHUB_DOWNLOAD_FAILED"}, status_code=502)

    if upstream.status_code != 200:
        upstream_status = upstream.status_code
        await upstream.aclose()
        await client.aclose()
        return JSONResponse(
            content={"error": f"GITHUB_DOWNLOAD_FAILED (GitHub HTTP {upstream_status})"},
            status_code=502,
        )

    async def stream_generator():
        # Relay the body chunk by chunk so large archives are never held in memory;
        # the upstream response and client are closed once streaming ends.
        try:
            async for chunk in upstream.aiter_bytes():
                yield chunk
        except Exception:
            # Headers were already sent; keep the legacy in-band failure marker.
            yield b"GITHUB_DOWNLOAD_FAILED"
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(stream_generator(), media_type="application/zip")
# ==========================================
# 新增:工作流/应用 (App) JSON 代理下载接口
# ==========================================
class ProxyDownloadRequest(BaseModel):
    """Request body for ``POST /api/proxy_download``."""

    # URL of the workflow/app JSON file that proxy_download fetches on the caller's behalf.
    url: str
    # ID of the item in items.json used for the price/ownership check.
    item_id: str
    # Caller's account name, compared against the item's author and Ownership records.
    account: str
@router.post("/api/proxy_download")
async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)):
    """Cloud proxy: after checking ownership, fetch a workflow/app JSON file and relay it locally.

    Responses:
      * 200 - the upstream body, returned as ``application/json`` (not validated as JSON)
      * 404 - ``item_id`` is not present in items.json
      * 403 - the item is paid, the caller is not its author and has no Ownership record
      * 400 - ``url`` is not an ``http``/``https`` URL
      * upstream status - the upstream server answered with an HTTP error
      * 500 - any other network/fetch failure

    NOTE(review): ``req_data.url`` comes from the client and is never compared with the
    item record, so the ownership check does not restrict *which* file is fetched, and the
    server will request any http(s) host, including internal addresses (SSRF). Deriving the
    URL from ``item`` would close both gaps. ``req_data.account`` is likewise client-supplied.
    """
    items_db = json_db.load_data("items.json", default_data=[])
    item = next((i for i in items_db if i.get("id") == req_data.item_id), None)
    if not item:
        return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404)

    # 1. Authorization: paid workflows require a purchase record unless the caller is the
    #    author. A missing/empty price counts as free; an unparseable price becomes NaN, and
    #    `not (nan <= 0)` is True, so such items fail closed (treated as paid).
    try:
        price = float(item.get("price") or 0)
    except (TypeError, ValueError):
        price = float("nan")
    is_paid = not (price <= 0)
    if is_paid and req_data.account != item.get("author"):
        owned = (
            db.query(Ownership)
            .filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id)
            .first()
        )
        if not owned:
            return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403)

    target_url = req_data.url

    # 2. Only web URLs may be proxied: urllib would otherwise also open file:// (local files
    #    on this server) and other schemes. urlsplit lowercases scheme and hostname.
    try:
        parts = urllib.parse.urlsplit(target_url)
    except ValueError:  # e.g. a malformed IPv6 literal
        parts = None
    host = parts.hostname if parts is not None else None
    if parts is None or parts.scheme not in ("http", "https") or not host:
        return JSONResponse(content={"error": "无效的下载地址,仅支持 http/https 协议"}, status_code=400)

    # Attach the server's Hugging Face token only for an exact HF host over HTTPS. A plain
    # substring test would also match e.g. https://attacker.example/?huggingface.co and
    # hand the token to that host.
    hf_token = os.environ.get("HF_TOKEN")
    headers = {"User-Agent": "ComfyUI-Ranking-SaaS"}
    is_hf_host = host == "huggingface.co" or host.endswith(".huggingface.co")
    if hf_token and parts.scheme == "https" and is_hf_host:
        headers["Authorization"] = f"Bearer {hf_token}"

    def _fetch():
        # urllib is kept deliberately: per the original author's note it forwards the URL
        # without re-encoding it, which keeps Chinese file names intact.
        req = urllib.request.Request(target_url, headers=headers)
        with urllib.request.urlopen(req, timeout=30) as resp:
            return resp.read()

    # 3. urlopen blocks, so run it in the default thread pool to keep the event loop free.
    try:
        content = await asyncio.get_running_loop().run_in_executor(None, _fetch)
    except urllib.error.HTTPError as e:
        return JSONResponse(content={"error": f"源文件拉取失败,HTTP状态码: {e.code}"}, status_code=e.code)
    except Exception as e:
        return JSONResponse(content={"error": f"代理下载时发生网络异常: {str(e)}"}, status_code=500)
    return Response(content=content, media_type="application/json")