File size: 12,573 Bytes
32b7155
 
 
 
8f9d15a
 
32b7155
 
8f9d15a
32b7155
 
 
 
 
 
 
 
8f9d15a
32b7155
8f9d15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b7155
 
 
 
 
 
 
 
 
7e0f067
 
 
 
32b7155
 
 
 
 
25f34c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b7155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21b27b1
 
 
 
 
 
 
 
 
 
 
 
 
 
32b7155
21b27b1
 
 
 
 
 
32b7155
21b27b1
 
 
32b7155
21b27b1
 
 
 
32b7155
 
 
 
 
8f9d15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32b7155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b1406b
8f9d15a
 
 
 
 
32b7155
8f9d15a
3b1406b
 
 
 
 
 
 
 
 
 
 
 
 
32b7155
3b1406b
32b7155
 
 
 
 
 
 
3b1406b
 
 
32b7155
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# router_proxy.py
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, Response
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, field_validator, HttpUrl
from typing import Optional
import httpx
import os
import re
import urllib.request
import urllib.error
import 数据库连接 as json_db
from database_sql import get_db
from models_sql import Ownership

router = APIRouter()


class ProxyGithubZipRequest(BaseModel):
    """GitHub ZIP 代理下载请求模型"""
    url: str = Field(..., min_length=1, description="资源URL")
    item_id: str = Field(..., min_length=1, max_length=100, description="资源ID")
    account: str = Field(..., min_length=1, max_length=50, description="用户账号")
    
    @field_validator('item_id')
    @classmethod
    def validate_item_id(cls, v: str) -> str:
        """验证item_id只允许[a-zA-Z0-9_-]字符"""
        if not re.match(r'^[a-zA-Z0-9_-]+$', v):
            raise ValueError('item_id只能包含字母、数字、下划线和连字符')
        return v
    
    @field_validator('account')
    @classmethod
    def validate_account(cls, v: str) -> str:
        """验证account非空且长度符合要求"""
        if len(v) < 1:
            raise ValueError('account不能为空')
        if len(v) > 50:
            raise ValueError('account长度不能超过50个字符')
        return v

@router.post("/api/proxy_github_zip")
async def proxy_github_zip(req_data: ProxyGithubZipRequest, db: Session = Depends(get_db)):
    """云端代理:校验所有权后,拉取 GitHub 私有库 ZIP 流透传给本地"""
    items_db = json_db.load_data("items.json", default_data=[])
    item = next((i for i in items_db if i["id"] == req_data.item_id), None)
    if not item: return JSONResponse(content={"error": "资源不存在"}, status_code=404)
    
    # 1. 核心鉴权:验证用户是否购买过该工具
    try:
        price = int(float(item.get("price", 0) or 0))
    except (ValueError, TypeError):
        price = 0
    if price > 0 and req_data.account != item.get("author"):
        owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
        if not owned:
            return JSONResponse(content={"error": "🚨 拒绝访问:未找到购买记录!"}, status_code=403)

    # ===== ZIP 缓存优先机制(鉴权后、GitHub请求前) =====
    try:
        versions_db = json_db.load_data("versions.json", default_data={})
        item_version = versions_db.get(req_data.item_id)
        
        # 兼容新格式:{"item_id": {"hash": "...", "cached_at": "...", "zip_size": ...}}
        if isinstance(item_version, dict) and item_version.get("cached_at"):
            cache_filename = f"zips/{req_data.item_id}_{item_version['hash']}.zip"
            cache_url = f"https://huggingface.co/datasets/ZHIWEI666/ComfyUI-Ranking/resolve/main/{cache_filename}"
            
            cache_client = httpx.AsyncClient(follow_redirects=True)
            head_resp = None
            try:
                head_resp = await cache_client.head(cache_url, timeout=10, follow_redirects=True)
            except Exception as head_err:
                print(f"[ZIP缓存] HEAD验证失败: {head_err}")
                await cache_client.aclose()
                raise
            
            if head_resp.status_code == 200:
                print(f"[ZIP缓存] 命中: {req_data.item_id}, 大小: {item_version.get('zip_size', 'unknown')}")
                
                content_length = head_resp.headers.get('content-length', '') or str(item_version.get('zip_size', ''))
                
                try:
                    stream_resp = await cache_client.send(
                        cache_client.build_request("GET", cache_url),
                        stream=True
                    )
                except Exception as stream_err:
                    print(f"[ZIP缓存] 流请求失败: {stream_err}")
                    await cache_client.aclose()
                    raise
                
                cache_headers = {}
                if content_length:
                    cache_headers["Content-Length"] = content_length
                
                async def cache_stream_generator():
                    try:
                        if stream_resp.status_code != 200:
                            yield b"CACHE_DOWNLOAD_FAILED"
                            return
                        async for chunk in stream_resp.aiter_bytes():
                            yield chunk
                    except Exception:
                        yield b"CACHE_DOWNLOAD_FAILED"
                    finally:
                        await stream_resp.aclose()
                        await cache_client.aclose()
                
                return StreamingResponse(cache_stream_generator(), media_type="application/zip", headers=cache_headers)
            
            # 缓存未命中:关闭 client 后继续原有逻辑
            await cache_client.aclose()
    except Exception as e:
        print(f"[ZIP缓存] 查询失败,fallback到GitHub直连: {e}")

    # 2. 解析 GitHub 仓库信息
    repo_url = item.get("link", "").rstrip("/")
    if not repo_url.startswith("https://github.com/"):
        return JSONResponse(content={"error": "无效的仓库地址,目前仅支持 GitHub 私有库代理"}, status_code=400)

    repo_parts = repo_url.split("/")
    if len(repo_parts) < 2: return JSONResponse(content={"error": "无效的仓库地址格式"}, status_code=400)
    owner = repo_parts[-2]
    repo = repo_parts[-1].replace(".git", "")
    
    # GitHub 官方提供的打包下载 API
    github_zip_api = f"https://api.github.com/repos/{owner}/{repo}/zipball"
    
    # 优先读取该资源在数据库中绑定的专属创作者 Token
    creator_token = item.get("github_token")
    # 如果没填,尝试使用官方全局兜底的 PAT
    fallback_token = os.environ.get("GITHUB_PAT") 
    
    active_token = creator_token if creator_token else fallback_token

    headers = {
        "Accept": "application/vnd.github.v3+json",
        "User-Agent": "ComfyUI-Ranking-SaaS"
    }
    if active_token:
        headers["Authorization"] = f"Bearer {active_token}"

    # 3. 异步请求 GitHub API 并以流形式透传回客户端 (防内存打爆)
    client = httpx.AsyncClient(follow_redirects=True)
    try:
        response = await client.send(
            client.build_request("GET", github_zip_api, headers=headers),
            stream=True
        )

        # 获取 Content-Length 并透传给客户端,使本地端能计算下载进度
        content_length = response.headers.get('content-length', '')
        resp_headers = {}
        if content_length:
            resp_headers["Content-Length"] = content_length

        async def stream_generator():
            try:
                if response.status_code != 200:
                    yield b"GITHUB_DOWNLOAD_FAILED"
                    return
                async for chunk in response.aiter_bytes():
                    yield chunk
            except Exception:
                yield b"GITHUB_DOWNLOAD_FAILED"
            finally:
                await response.aclose()
                await client.aclose()

        return StreamingResponse(stream_generator(), media_type="application/zip", headers=resp_headers)
    except Exception as e:
        await client.aclose()
        return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500)

# ==========================================
# 新增:工作流/应用 (App) JSON 代理下载接口
# ==========================================
class ProxyDownloadRequest(BaseModel):
    """工作流/应用 JSON 代理下载请求模型"""
    url: str = Field(..., min_length=1, description="下载URL")
    item_id: str = Field(..., min_length=1, max_length=100, description="资源ID")
    account: str = Field(..., min_length=1, max_length=50, description="用户账号")
    
    @field_validator('item_id')
    @classmethod
    def validate_item_id(cls, v: str) -> str:
        """验证item_id只允许[a-zA-Z0-9_-]字符"""
        if not re.match(r'^[a-zA-Z0-9_-]+$', v):
            raise ValueError('item_id只能包含字母、数字、下划线和连字符')
        return v
    
    @field_validator('account')
    @classmethod
    def validate_account(cls, v: str) -> str:
        """验证account非空且长度符合要求"""
        if len(v) < 1:
            raise ValueError('account不能为空')
        if len(v) > 50:
            raise ValueError('account长度不能超过50个字符')
        return v
    
    @field_validator('url')
    @classmethod
    def validate_url(cls, v: str) -> str:
        """验证URL格式"""
        if not v.startswith(('http://', 'https://')):
            raise ValueError('url必须是有效的HTTP或HTTPS地址')
        return v

@router.post("/api/proxy_download")
async def proxy_download(req_data: ProxyDownloadRequest, db: Session = Depends(get_db)):
    """云端代理:校验所有权后,拉取真实的工作流 JSON 文件并透传给本地"""
    items_db = json_db.load_data("items.json", default_data=[])
    item = next((i for i in items_db if i["id"] == req_data.item_id), None)
    
    if not item: 
        return JSONResponse(content={"error": "云端记录中找不到该资源"}, status_code=404)
    
    # 1. 核心鉴权:验证用户是否购买/获取过该工作流
    price = int(item.get("price", 0))
    if price > 0 and req_data.account != item.get("author"):
        owned = db.query(Ownership).filter(Ownership.account == req_data.account, Ownership.item_id == req_data.item_id).first()
        if not owned:
            return JSONResponse(content={"error": "您尚未获取该工作流,请先在社区页面点击获取"}, status_code=403)

    target_url = req_data.url
    
    # 🚀 核心修复:从环境变量提取 Hugging Face Token,并组装 Authorization 请求头
    hf_token = os.environ.get("HF_TOKEN") 
    
    headers = {"User-Agent": "ComfyUI-Ranking-SaaS"}
    if hf_token and "huggingface.co" in target_url:
        headers["Authorization"] = f"Bearer {hf_token}"
    
    # 🚀 核心修复:使用异步 httpx 替代同步 urllib,避免阻塞事件循环
    # ⚠️ 安全警告:仅在网络环境导致证书验证失败时,通过环境变量临时关闭
    verify_ssl = os.environ.get("DISABLE_SSL_VERIFY", "").lower() not in ("1", "true")
    if not verify_ssl:
        print("⚠️ SSL证书验证已关闭,请仅在调试环境使用")
    
    try:
        async with httpx.AsyncClient(follow_redirects=True, verify=verify_ssl, timeout=120.0) as client:
            print(f"🔍 开始下载资源 [{req_data.item_id}]")
            print(f"🔗 目标地址:{target_url[:80]}...")
            
            response = await client.get(target_url, headers=headers)
            
            if response.status_code != 200:
                print(f"❌ 源文件拉取失败 [HTTP {response.status_code}]")
                return JSONResponse(
                    content={"error": f"源文件拉取失败,HTTP 状态码:{response.status_code}"}, 
                    status_code=response.status_code
                )
            
            content = response.content
            print(f"✅ 成功下载资源 [{req_data.item_id}], 大小:{len(content)} bytes")
            
            return Response(
                content=content, 
                media_type="application/json", 
                status_code=200,
                headers={"Content-Disposition": f"attachment; filename={req_data.item_id}.json"}
            )
            
    except httpx.TimeoutException as e:
        print(f"❌ 下载超时:{str(e)}")
        return JSONResponse(content={"error": "下载超时,请稍后重试"}, status_code=504)
    except Exception as e:
        import traceback
        print(f"❌ 代理下载异常:{str(e)}")
        print(traceback.format_exc())
        return JSONResponse(content={"error": f"代理下载时发生网络异常:{str(e)}"}, status_code=500)