ZHIWEI666 commited on
Commit
a90ed05
·
verified ·
1 Parent(s): ab01d5b

Upload router_proxy.py

Browse files
Files changed (1) hide show
  1. router_proxy.py +137 -120
router_proxy.py CHANGED
@@ -1,120 +1,137 @@
1
- # router_proxy.py
2
- from fastapi import APIRouter, Depends, HTTPException
3
- from fastapi.responses import StreamingResponse, JSONResponse, Response
4
- from sqlalchemy.orm import Session
5
- from pydantic import BaseModel
6
- import httpx
7
- import os
8
- import urllib.request
9
- import urllib.error
10
- import 数据库连接 as json_db
11
- from database_sql import get_db
12
- from models_sql import Ownership
13
-
14
- router = APIRouter()
15
-
16
- class ProxyGithubZipRequest(BaseModel):
17
- url: str
18
- item_id: str
19
- account: str
20
-
21
- @router.post("/api/proxy_github_zip")
22
- async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)):
23
- """云端代理:校验所有权后,拉取 GitHub 私有库 ZIP 流透传给本地"""
24
- items_db = json_db.load_data("items.json", default_data=[])
25
- item = next((i for i in items_db if i["id"] == req_data.item_id), None)
26
- if not item: return JSONResponse(content={"error": "资源不存在"}, status_code=404)
27
-
28
- # 1. 核心鉴权:验证用户是否购买过该工具
29
- price = int(item.get("price", 0))
30
- if price > 0 and req_data.account != item.get("author"):
31
- owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
32
- if not owned:
33
- return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403)
34
-
35
- # 2. 解析 GitHub 仓库信息
36
- repo_url = item.get("link", "").rstrip("/")
37
- if not repo_url.startswith("https://github.com/"):
38
- return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400)
39
-
40
- repo_parts = repo_url.split("/")
41
- if len(repo_parts) < 2: return JSONResponse(content={"error": "无效的仓库地址格式"}, status_code=400)
42
- owner = repo_parts[-2]
43
- repo = repo_parts[-1].replace(".git", "")
44
-
45
- # GitHub 官方提供打包下载 API
46
- github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball"
47
-
48
- # 优先读取该资源在数据库中绑定的专属创作者 Token
49
- creator_token = item.get("github_token")
50
- # 如果没填,尝试使用官方全局兜底的 PAT
51
- fallback_token = os.environ.get("GITHUB_PAT")
52
-
53
- active_token = creator_token if creator_token else fallback_token
54
-
55
- headers = {
56
- "Accept": "application/vnd.github.v3+json",
57
- "User-Agent": "ComfyUI-Ranking-SaaS"
58
- }
59
- if active_token:
60
- headers["Authorization"] = f"Bearer {active_token}"
61
-
62
- # 3. 异步请求 GitHub API 并以流形式透传回客户端 (防内存打爆)
63
- async def stream_generator():
64
- async with httpx.AsyncClient(follow_redirects=True) as client:
65
- try:
66
- async with client.stream("GET", github_zip_api, headers=headers, timeout=120.0) as response:
67
- if response.status_code != 200:
68
- yield b"GITHUB_DOWNLOAD_FAILED"
69
- return
70
- async for chunk in response.aiter_bytes():
71
- yield chunk
72
- except Exception as e:
73
- yield b"GITHUB_DOWNLOAD_FAILED"
74
-
75
- return StreamingResponse(stream_generator(), media_type="application/zip")
76
-
77
- # ==========================================
78
- # 新增:工作流/应用 (App) JSON 代理下载接口
79
- # ==========================================
80
- class ProxyDownloadRequest(BaseModel):
81
- url: str
82
- item_id: str
83
- account: str
84
-
85
- @router.post("/api/proxy_download")
86
- async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)):
87
- """云端代理:校验所有权后,拉取真实的工作流 JSON 文件并透传给本地"""
88
- items_db = json_db.load_data("items.json", default_data=[])
89
- item = next((i for i in items_db if i["id"] == req_data.item_id), None)
90
-
91
- if not item:
92
- return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404)
93
-
94
- # 1. 核心鉴权:验证用户是否购买/获取过该工作流
95
- price = int(item.get("price", 0))
96
- if price > 0 and req_data.account != item.get("author"):
97
- owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
98
- if not owned:
99
- return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403)
100
-
101
- target_url = req_data.url
102
-
103
- # 🚀 核心修复:从环境变量提取 Hugging Face Token,并组装 Authorization 请求头
104
- hf_token = os.environ.get("HF_TOKEN")
105
-
106
- headers = {"User-Agent": "ComfyUI-Ranking-SaaS"}
107
- if hf_token and "huggingface.co" in target_url:
108
- headers["Authorization"] = f"Bearer {hf_token}"
109
-
110
- # 2. 使用 urllib 拉取真实 JSON 数据(保持原始中文文件名,不自动编码)
111
- try:
112
- req = urllib.request.Request(target_url, headers=headers)
113
- with urllib.request.urlopen(req, timeout=30) as resp:
114
- content = resp.read()
115
- return Response(content=content, media_type="application/json")
116
-
117
- except urllib.error.HTTPError as e:
118
- return JSONResponse(content={"error": f"源文件拉取失败,HTTP状态码: {e.code}"}, status_code=e.code)
119
- except Exception as e:
120
- return JSONResponse(content={"error": f"代理下载时发生网络异常: {str(e)}"}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # router_proxy.py
2
+ from fastapi import APIRouter, Depends, HTTPException
3
+ from fastapi.responses import StreamingResponse, JSONResponse, Response
4
+ from sqlalchemy.orm import Session
5
+ from pydantic import BaseModel
6
+ import httpx
7
+ import os
8
+ import urllib.request
9
+ import urllib.error
10
+ import 数据库连接 as json_db
11
+ from database_sql import get_db
12
+ from models_sql import Ownership
13
+
14
+ router = APIRouter()
15
+
16
+ class ProxyGithubZipRequest(BaseModel):
17
+ url: str
18
+ item_id: str
19
+ account: str
20
+
21
+ @router.post("/api/proxy_github_zip")
22
+ async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)):
23
+ """云端代理:校验所有权后,拉取 GitHub 私有库 ZIP 流透传给本地"""
24
+ items_db = json_db.load_data("items.json", default_data=[])
25
+ item = next((i for i in items_db if i["id"] == req_data.item_id), None)
26
+ if not item: return JSONResponse(content={"error": "资源不存在"}, status_code=404)
27
+
28
+ # 1. 核心鉴权:验证用户是否购买过该工具
29
+ price = int(item.get("price", 0))
30
+ if price > 0 and req_data.account != item.get("author"):
31
+ owned = db.query(Ownership).filter(
32
+ Ownership.account == req_data.account,
33
+ Ownership.item_id == req_data.item_id,
34
+ Ownership.is_refunded == False
35
+ ).first()
36
+ if not owned:
37
+ return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403)
38
+
39
+ # 2. 解析 GitHub 仓库信息
40
+ repo_url = item.get("link", "").rstrip("/")
41
+ if not repo_url.startswith("https://github.com/"):
42
+ return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400)
43
+
44
+ repo_parts = repo_url.split("/")
45
+ if len(repo_parts) < 2: return JSONResponse(content={"error": "无效仓库地址格式"}, status_code=400)
46
+ owner = repo_parts[-2]
47
+ repo = repo_parts[-1].replace(".git", "")
48
+
49
+ # GitHub 官方提供的打包下载 API
50
+ github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball"
51
+
52
+ # 优先读取该资源在数据库中绑定的专属创作者 Token
53
+ creator_token = item.get("github_token")
54
+ # 如果没填,尝试使用官方全局兜底的 PAT
55
+ fallback_token = os.environ.get("GITHUB_PAT")
56
+
57
+ active_token = creator_token if creator_token else fallback_token
58
+
59
+ headers = {
60
+ "Accept": "application/vnd.github.v3+json",
61
+ "User-Agent": "ComfyUI-Ranking-SaaS"
62
+ }
63
+ if active_token:
64
+ headers["Authorization"] = f"Bearer {active_token}"
65
+
66
+ # 3. 异步请求 GitHub API 并以流形式透传回客户端 (防内存打爆)
67
+ async def stream_generator():
68
+ async with httpx.AsyncClient(follow_redirects=True) as client:
69
+ try:
70
+ async with client.stream("GET", github_zip_api, headers=headers, timeout=120.0) as response:
71
+ if response.status_code != 200:
72
+ yield b"GITHUB_DOWNLOAD_FAILED"
73
+ return
74
+ async for chunk in response.aiter_bytes():
75
+ yield chunk
76
+ except Exception as e:
77
+ yield b"GITHUB_DOWNLOAD_FAILED"
78
+
79
+ return StreamingResponse(stream_generator(), media_type="application/zip")
80
+
81
+ # ==========================================
82
+ # 新增:工作流/应用 (App) JSON 代理下载接口
83
+ # ==========================================
84
+ class ProxyDownloadRequest(BaseModel):
85
+ url: str
86
+ item_id: str
87
+ account: str
88
+
89
+ @router.post("/api/proxy_download")
90
+ async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)):
91
+ """云端代理:校验所有权后,拉取真实的工作流 JSON 文件并透传给本地"""
92
+ try:
93
+ items_db = json_db.load_data("items.json", default_data=[])
94
+ item = next((i for i in items_db if i["id"] == req_data.item_id), None)
95
+
96
+ if not item:
97
+ return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404)
98
+
99
+ # 1. 核心鉴权:验证用户是否购买/获取该工作流
100
+ price = int(item.get("price", 0))
101
+ if price > 0 and req_data.account != item.get("author"):
102
+ try:
103
+ owned = db.query(Ownership).filter(
104
+ Ownership.account == req_data.account,
105
+ Ownership.item_id == req_data.item_id,
106
+ Ownership.is_refunded == False
107
+ ).first()
108
+ if not owned:
109
+ return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403)
110
+ except Exception as e:
111
+ print(f"[proxy_download] 数据库查询失败: {e}")
112
+ return JSONResponse(content={"error": f"鉴权查询失败: {str(e)}"}, status_code=500)
113
+
114
+ target_url = req_data.url
115
+
116
+ # 🚀 核心修复:从环境变量提取 Hugging Face Token,并组装 Authorization 请求头
117
+ hf_token = os.environ.get("HF_TOKEN")
118
+
119
+ headers = {"User-Agent": "ComfyUI-Ranking-SaaS"}
120
+ if hf_token and "huggingface.co" in target_url:
121
+ headers["Authorization"] = f"Bearer {hf_token}"
122
+
123
+ # 2. 使用 urllib 拉取真实 JSON 数据(保持原始中文文件名,不自动编码)
124
+ try:
125
+ req = urllib.request.Request(target_url, headers=headers)
126
+ with urllib.request.urlopen(req, timeout=30) as resp:
127
+ content = resp.read()
128
+ return Response(content=content, media_type="application/json")
129
+
130
+ except urllib.error.HTTPError as e:
131
+ return JSONResponse(content={"error": f"源文件拉取失败,HTTP状态码: {e.code}"}, status_code=e.code)
132
+ except Exception as e:
133
+ return JSONResponse(content={"error": f"代理下载时发生网络异常: {str(e)}"}, status_code=500)
134
+
135
+ except Exception as e:
136
+ print(f"[proxy_download] 未处理的异常: {e}")
137
+ return JSONResponse(content={"error": f"服务器内部错误: {str(e)}"}, status_code=500)