|
|
import json |
|
|
import uuid |
|
|
import urllib.request |
|
|
import urllib.parse |
|
|
import urllib.error |
|
|
import os |
|
|
import random |
|
|
import time |
|
|
import shutil |
|
|
import asyncio |
|
|
import requests |
|
|
import httpx |
|
|
from typing import List, Dict, Any, Optional |
|
|
from threading import Lock |
|
|
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, UploadFile, File |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.responses import FileResponse, Response |
|
|
from pydantic import BaseModel |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
|
|
|
app = FastAPI()

# Allow browser clients from any origin to call the API and open WebSockets.
# NOTE(review): wide-open CORS — acceptable for a LAN/demo tool, confirm for
# public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
class ConnectionManager:
    """Tracks open stats WebSockets and broadcasts events to all of them."""

    def __init__(self):
        # Every open socket, in connection order; used for broadcasts.
        self.active_connections: List[WebSocket] = []
        # Optional client_id -> socket mapping for targeted messages.
        self.user_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, client_id: str = None):
        """Accept a socket, register it and broadcast the new online count."""
        await websocket.accept()
        self.active_connections.append(websocket)
        if client_id:
            self.user_connections[client_id] = websocket
        print(f"WS Connected. Total: {len(self.active_connections)}")
        await self.broadcast_count()

    async def disconnect(self, websocket: WebSocket, client_id: str = None):
        """Unregister a socket and broadcast the updated online count."""
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        if client_id and client_id in self.user_connections:
            del self.user_connections[client_id]
        print(f"WS Disconnected. Total: {len(self.active_connections)}")
        await self.broadcast_count()

    async def send_personal_message(self, message: dict, client_id: str):
        """Send a JSON message to one client; drop its connection on failure."""
        if client_id in self.user_connections:
            try:
                await self.user_connections[client_id].send_text(json.dumps(message))
            except Exception as e:
                print(f"WS Send Error ({client_id}): {e}")
                # BUG FIX: disconnect() is a coroutine and was previously called
                # without await, so the dead socket was never cleaned up.
                await self.disconnect(self.user_connections[client_id], client_id)

    async def broadcast_count(self):
        """Push the current online count to every connection, pruning dead ones."""
        count = len(self.active_connections)
        data = json.dumps({"type": "stats", "online_count": count})
        print(f"Broadcasting online count: {count}")

        # Iterate over a copy: failed sends remove entries from the live list.
        for connection in self.active_connections[:]:
            try:
                await connection.send_text(data)
            except Exception as e:
                print(f"Broadcast error for client {id(connection)}: {e}")
                self.active_connections.remove(connection)

    async def broadcast_new_image(self, image_data: dict):
        """Broadcast a newly generated image record to all clients."""
        data = json.dumps({"type": "new_image", "data": image_data})
        print(f"Broadcasting new image to {len(self.active_connections)} clients")
        # Iterate over a copy: failed sends remove entries from the live list.
        for connection in self.active_connections[:]:
            try:
                await connection.send_text(data)
            except Exception as e:
                print(f"Broadcast image error for client {id(connection)}: {e}")
                self.active_connections.remove(connection)
|
|
|
|
|
# Single shared connection manager for the whole process.
manager = ConnectionManager()

# The server's asyncio event loop, captured at startup so synchronous
# endpoints (which FastAPI runs in a threadpool) can schedule broadcasts
# onto it via run_coroutine_threadsafe.
GLOBAL_LOOP = None
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Capture the running event loop for use from worker threads."""
    global GLOBAL_LOOP
    GLOBAL_LOOP = asyncio.get_running_loop()
|
|
|
|
|
@app.websocket("/ws/stats")
async def websocket_endpoint(websocket: WebSocket, client_id: str = None):
    """Hold a stats WebSocket open, answering application-level pings."""
    await manager.connect(websocket, client_id)
    try:
        while True:
            incoming = await websocket.receive_text()
            if incoming == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
    except WebSocketDisconnect:
        print(f"WebSocket disconnected normally: {id(websocket)}")
        await manager.disconnect(websocket, client_id)
    except Exception as e:
        print(f"WS Error for {id(websocket)}: {e}")
        await manager.disconnect(websocket, client_id)
|
|
|
|
|
|
|
|
|
|
|
# Available ComfyUI backend instances; /api/generate load-balances across them.
COMFYUI_INSTANCES = [
    "127.0.0.1:8188",
    "127.0.0.1:4090",
]

# Default backend (first instance), used where no balancing is needed.
COMFYUI_ADDRESS = COMFYUI_INSTANCES[0]

# Stable client id reported to ComfyUI for every proxied prompt.
CLIENT_ID = str(uuid.uuid4())
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
WORKFLOW_DIR = os.path.join(BASE_DIR, "workflows")
# Default workflow file, used as fallback in /api/generate.
WORKFLOW_PATH = os.path.join(WORKFLOW_DIR, "Z-Image.json")
STATIC_DIR = os.path.join(BASE_DIR, "static")
OUTPUT_DIR = os.path.join(BASE_DIR, "output")
HISTORY_FILE = os.path.join(BASE_DIR, "history.json")
# In-process generation queue (list of {"task_id", "client_id"} dicts).
QUEUE = []
QUEUE_LOCK = Lock()
HISTORY_LOCK = Lock()

# Monotonically increasing task id; guarded by QUEUE_LOCK.
NEXT_TASK_ID = 1

# Requests this process currently has in flight per backend, combined with
# each backend's own queue length when picking where to dispatch.
BACKEND_LOCAL_LOAD = {addr: 0 for addr in COMFYUI_INSTANCES}
LOAD_LOCK = Lock()

# Make sure all working directories exist before mounting/serving them.
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(STATIC_DIR, exist_ok=True)
os.makedirs(WORKFLOW_DIR, exist_ok=True)

# Server-wide settings file (e.g. shared ModelScope token) and its guard.
GLOBAL_CONFIG_FILE = os.path.join(BASE_DIR, "global_config.json")
GLOBAL_CONFIG_LOCK = Lock()

# Serve the frontend and the mirrored output images directly from disk.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/output", StaticFiles(directory=OUTPUT_DIR), name="output")
|
|
|
|
|
class GenerateRequest(BaseModel):
    """Payload for /api/generate (local ComfyUI generation)."""

    prompt: str = ""
    width: int = 1024
    height: int = 1024
    # Workflow file name resolved under WORKFLOW_DIR.
    workflow_json: str = "Z-Image.json"
    # Per-node input overrides: {node_id: {input_name: value}}.
    params: Dict[str, Any] = {}
    # Logical task type, used to tag history entries and output file prefixes.
    type: str = "zimage"
    # WebSocket client id, used for queue-position lookups.
    client_id: str = ""
|
|
|
|
|
class CloudGenRequest(BaseModel):
    """Payload for the ModelScope cloud generation endpoints."""

    prompt: str
    # Caller-supplied ModelScope API token.
    api_key: str
    resolution: str = "1024x1024"
    client_id: Optional[str] = None
    type: str = "default"
    # Source images for edit-style models (e.g. Qwen-Image-Edit).
    image_urls: List[str] = []
    model: str = ""
|
|
|
|
|
class DeleteHistoryRequest(BaseModel):
    """Identifies a history record to delete by its timestamp."""

    timestamp: float
|
|
|
|
|
|
|
|
|
|
|
def get_best_backend():
    """选择队列压力最小的后端

    Effective load per backend is the max of the remote queue length
    (running + pending) and this process's own in-flight count; unreachable
    backends are skipped. Falls back to the first instance.
    """
    chosen = COMFYUI_INSTANCES[0]
    lowest = float('inf')

    for candidate in COMFYUI_INSTANCES:
        try:
            with urllib.request.urlopen(f"http://{candidate}/queue", timeout=1) as response:
                queue_info = json.loads(response.read())

            running = queue_info.get('queue_running', [])
            pending = queue_info.get('queue_pending', [])
            remote_load = len(running) + len(pending)

            with LOAD_LOCK:
                local_load = BACKEND_LOCAL_LOAD.get(candidate, 0)

            # The remote queue may lag behind requests we just dispatched,
            # so trust whichever count is higher.
            effective_load = max(remote_load, local_load)

            print(f"Backend {candidate} load: {effective_load} (Remote: {remote_load}, Local: {local_load})")

            if effective_load < lowest:
                lowest = effective_load
                chosen = candidate
        except Exception as e:
            print(f"Backend {candidate} unreachable: {e}")
            continue

    print(f"Selected backend: {chosen}")
    return chosen
|
|
|
|
|
|
|
|
|
|
|
def download_image(comfy_address, comfy_url_path, prefix="studio_"):
    """Mirror a remote ComfyUI image into OUTPUT_DIR and return its URL path.

    On success returns a local "/output/<filename>" path. On download failure
    returns a URL the frontend can still render: the path rewritten to go
    through this server's /api/view proxy, or the raw remote URL.
    """
    filename = f"{prefix}{uuid.uuid4().hex[:10]}.png"
    local_path = os.path.join(OUTPUT_DIR, filename)
    full_url = f"http://{comfy_address}{comfy_url_path}"
    try:
        with urllib.request.urlopen(full_url) as response, open(local_path, 'wb') as out_file:
            shutil.copyfileobj(response, out_file)
        # BUG FIX: previously returned the literal string "/output/(unknown)"
        # instead of interpolating the saved file's name.
        return f"/output/{filename}"
    except Exception as e:
        print(f"下载图片失败: {e} (URL: {full_url})")

        # Fall back to proxying through our own /api/view endpoint so the
        # frontend can still display the image.
        if comfy_url_path.startswith("/view"):
            return comfy_url_path.replace("/view", "/api/view", 1)
        return full_url
|
|
|
|
|
def save_to_history(record):
    """保存记录到 JSON 文件

    Prepends *record* to HISTORY_FILE (a JSON list), stamping a timestamp if
    missing and keeping at most the 5000 newest entries.
    """
    with HISTORY_LOCK:
        history = []
        if os.path.exists(HISTORY_FILE):
            try:
                with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
                    history = json.load(f)
            # BUG FIX: was a bare "except: pass" which also swallowed
            # KeyboardInterrupt/SystemExit; only file/parse errors are expected.
            except (OSError, json.JSONDecodeError):
                history = []
        # Guard against a corrupt file containing valid JSON that isn't a list.
        if not isinstance(history, list):
            history = []

        # Stamp the record so the frontend can sort and match it later.
        if "timestamp" not in record:
            record["timestamp"] = time.time()

        history.insert(0, record)

        with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
            json.dump(history[:5000], f, ensure_ascii=False, indent=4)
|
|
|
|
|
def get_comfy_history(comfy_address, prompt_id):
    """Fetch the ComfyUI /history entry for *prompt_id*; {} on any failure."""
    url = f"http://{comfy_address}/history/{prompt_id}"
    try:
        with urllib.request.urlopen(url) as response:
            body = response.read()
        return json.loads(body)
    except Exception:
        return {}
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/view")
def view_image(filename: str, type: str = "input", subfolder: str = ""):
    """Proxy an image from the primary ComfyUI backend's /view endpoint.

    Used as a fallback URL when an image could not be mirrored locally.
    Responds 404 on any upstream failure.
    """
    try:
        # Always proxies via the first backend. NOTE(review): an image that
        # only exists on another instance would 404 here — confirm intended.
        url = f"http://{COMFYUI_INSTANCES[0]}/view"
        params = {"filename": filename, "type": type, "subfolder": subfolder}
        r = requests.get(url, params=params)
        return Response(content=r.content, media_type=r.headers.get('Content-Type'))
    except Exception as e:
        raise HTTPException(status_code=404, detail="Image not found")
|
|
|
|
|
@app.post("/api/upload")
async def upload_image(files: List[UploadFile] = File(...)):
    """Upload user images to every ComfyUI backend.

    Each file's bytes are read once, then pushed to all instances so any
    backend can later reference the file by name. A file counts as uploaded
    if at least one backend accepted it; otherwise the request fails with 500.

    Returns:
        {"files": [{"comfy_name": <name ComfyUI stored the file under>}, ...]}
    """
    uploaded_files = []

    # Read every upload once up front so the same bytes can be re-sent
    # to each backend.
    files_content = []
    for file in files:
        content = await file.read()
        files_content.append((file, content))

    for file, content in files_content:
        success_count = 0
        last_result = None

        for addr in COMFYUI_INSTANCES:
            try:
                files_data = {'image': (file.filename, content, file.content_type)}

                response = requests.post(f"http://{addr}/upload/image", files=files_data, timeout=5)

                if response.status_code == 200:
                    last_result = response.json()
                    success_count += 1
                else:
                    print(f"Upload to {addr} failed: {response.text}")

            except Exception as e:
                print(f"Upload error for {addr}: {e}")

        if success_count > 0 and last_result:
            # ComfyUI may rename on collision, so report the name it assigned.
            uploaded_files.append({"comfy_name": last_result.get("name", file.filename)})
        else:
            raise HTTPException(status_code=500, detail=f"Failed to upload to any backend")

    return {"files": uploaded_files}
|
|
|
|
|
@app.get("/")
async def index():
    """Serve the single-page frontend."""
    index_path = os.path.join(STATIC_DIR, "index.html")
    return FileResponse(index_path)
|
|
|
|
|
@app.get("/api/history")
async def get_history_api(type: str = None):
    """Return history records, optionally filtered by task type.

    "cloud" records are folded into the "zimage" view, records without any
    images are dropped, and results are sorted newest-first. Returns [] when
    the history file is missing or unreadable.
    """
    if os.path.exists(HISTORY_FILE):
        try:
            with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)

            if type:
                # Records saved before the "type" field existed default
                # to "zimage".
                target_types = [type]
                if type == "zimage":
                    # Cloud Z-Image results appear in the zimage gallery too.
                    target_types.append("cloud")

                data = [item for item in data if item.get("type", "zimage") in target_types]

            # Hide records with no generated images (failed/empty runs).
            data = [item for item in data if item.get("images") and len(item["images"]) > 0]

            def sort_key(item):
                # Tolerate legacy records whose timestamp is missing or
                # non-numeric: they sort to the end.
                ts = item.get("timestamp", 0)
                if isinstance(ts, (int, float)):
                    return float(ts)
                return 0

            data.sort(key=sort_key, reverse=True)

            # Back-fill the is_cloud flag for records saved before the flag
            # existed, inferred from the cloud_* filename prefixes.
            for item in data:
                if "is_cloud" not in item and item.get("images"):
                    if any("cloud_angle" in img or "cloud_" in img for img in item["images"]):
                        item["is_cloud"] = True

            return data
        except Exception as e:
            print(f"读取历史文件失败: {e}")
            return []
    return []
|
|
|
|
|
@app.get("/api/queue_status")
async def get_queue_status(client_id: str):
    """Report queue length and this client's 1-based position (0 = not queued)."""
    with QUEUE_LOCK:
        total = len(QUEUE)
        position = 0
        for idx, task in enumerate(QUEUE, start=1):
            if task["client_id"] == client_id:
                position = idx
                break
    return {"total": total, "position": position}
|
|
|
|
|
@app.post("/api/history/delete")
async def delete_history(req: DeleteHistoryRequest):
    """Delete one history record (matched by timestamp) and its local images.

    Numeric timestamps match within a 1ms tolerance (floats lose precision
    across JSON round-trips); otherwise a string comparison is used for
    legacy records.
    """
    if not os.path.exists(HISTORY_FILE):
        return {"success": False, "message": "History file not found"}

    try:
        with HISTORY_LOCK:
            with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
                history = json.load(f)

            # Split history into the record to delete and everything else.
            target_record = None
            new_history = []
            for item in history:
                is_match = False
                item_ts = item.get("timestamp", 0)

                if isinstance(req.timestamp, (int, float)) and isinstance(item_ts, (int, float)):
                    # Tolerance comparison — see docstring.
                    if abs(float(item_ts) - float(req.timestamp)) < 0.001:
                        is_match = True
                elif str(item_ts) == str(req.timestamp):
                    is_match = True

                if is_match:
                    target_record = item
                else:
                    new_history.append(item)

            if target_record:
                with open(HISTORY_FILE, 'w', encoding='utf-8') as f:
                    json.dump(new_history, f, ensure_ascii=False, indent=4)

        # File removal happens after releasing the lock; only locally
        # mirrored images (paths under /output/) are deleted.
        if target_record:
            for img_url in target_record.get("images", []):
                if img_url.startswith("/output/"):
                    filename = img_url.split("/")[-1]
                    file_path = os.path.join(OUTPUT_DIR, filename)
                    if os.path.exists(file_path):
                        try:
                            os.remove(file_path)
                        except Exception as e:
                            print(f"Failed to delete file {file_path}: {e}")

            return {"success": True}
        else:
            return {"success": False, "message": "Record not found"}

    except Exception as e:
        print(f"Delete history error: {e}")
        return {"success": False, "message": str(e)}
|
|
|
|
|
class TokenRequest(BaseModel):
    """Payload carrying the shared ModelScope API token."""

    token: str
|
|
|
|
|
@app.get("/api/config/token")
async def get_global_token():
    """Return the shared ModelScope token, or "" if unset or unreadable."""
    if os.path.exists(GLOBAL_CONFIG_FILE):
        try:
            with open(GLOBAL_CONFIG_FILE, 'r', encoding='utf-8') as f:
                config = json.load(f)
            return {"token": config.get("modelscope_token", "")}
        # BUG FIX: was a bare "except:" — narrowed to the errors a corrupt or
        # unreadable config can actually raise (AttributeError covers valid
        # JSON that isn't a dict).
        except (OSError, json.JSONDecodeError, AttributeError):
            return {"token": ""}
    return {"token": ""}
|
|
|
|
|
@app.post("/api/config/token")
async def set_global_token(req: TokenRequest):
    """Store (or overwrite) the shared ModelScope token in the config file."""
    with GLOBAL_CONFIG_LOCK:
        config = {}
        if os.path.exists(GLOBAL_CONFIG_FILE):
            try:
                with open(GLOBAL_CONFIG_FILE, 'r', encoding='utf-8') as f:
                    config = json.load(f)
            # BUG FIX: was a bare "except: pass"; only file/parse errors are
            # expected here, anything else should surface.
            except (OSError, json.JSONDecodeError):
                config = {}
        # Guard against valid JSON that isn't an object.
        if not isinstance(config, dict):
            config = {}

        config["modelscope_token"] = req.token.strip()

        with open(GLOBAL_CONFIG_FILE, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)
    return {"success": True}
|
|
|
|
|
@app.delete("/api/config/token")
async def delete_global_token():
    """Remove the shared ModelScope token from the config file (best-effort)."""
    with GLOBAL_CONFIG_LOCK:
        if os.path.exists(GLOBAL_CONFIG_FILE):
            try:
                config = {}
                with open(GLOBAL_CONFIG_FILE, 'r', encoding='utf-8') as f:
                    config = json.load(f)

                if "modelscope_token" in config:
                    del config["modelscope_token"]
                    with open(GLOBAL_CONFIG_FILE, 'w', encoding='utf-8') as f:
                        json.dump(config, f, indent=4)
            # BUG FIX: was a bare "except: pass" — narrowed to the expected
            # file/parse errors (TypeError covers non-dict JSON). Deletion
            # stays best-effort: the endpoint still reports success.
            except (OSError, json.JSONDecodeError, TypeError):
                pass
    return {"success": True}
|
|
|
|
|
class CloudPollRequest(BaseModel):
    """Payload for resuming the poll of an existing ModelScope task."""

    task_id: str
    api_key: str
    client_id: Optional[str] = None
|
|
|
|
|
@app.post("/api/angle/poll_status")
async def poll_angle_cloud(req: CloudPollRequest):
    """
    Resume polling for an existing Angle task.

    Polls ModelScope every 2s for up to 300 attempts (~10 minutes), streaming
    status updates to the requesting client over WebSocket. On success the
    result image is mirrored into OUTPUT_DIR and recorded in the history file.

    Raises:
        HTTPException: if the remote task reports FAILED or another
            unrecoverable error occurs.
    """
    base_url = 'https://api-inference.modelscope.cn/'
    clean_token = req.api_key.strip()

    headers = {
        "Authorization": f"Bearer {clean_token}",
        "Content-Type": "application/json",
        "X-ModelScope-Async-Mode": "true"
    }

    task_id = req.task_id
    print(f"Resuming polling for Angle Task: {task_id}")

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            for i in range(300):
                await asyncio.sleep(2)
                try:
                    result = await client.get(
                        f"{base_url}v1/tasks/{task_id}",
                        headers={**headers, "X-ModelScope-Task-Type": "image_generation"},
                    )
                    data = result.json()
                    status = data.get("task_status")

                    if status == "SUCCEED":
                        img_url = data["output_images"][0]
                        print(f"Angle Task SUCCEED: {img_url}")

                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": "SUCCEED",
                                "task_id": task_id
                            }, req.client_id)

                        # Mirror the remote image locally; fall back to the
                        # remote URL if the download fails.
                        local_path = ""
                        try:
                            async with httpx.AsyncClient() as dl_client:
                                img_res = await dl_client.get(img_url)
                                if img_res.status_code == 200:
                                    filename = f"cloud_angle_{int(time.time())}.png"
                                    file_path = os.path.join(OUTPUT_DIR, filename)
                                    with open(file_path, "wb") as f:
                                        f.write(img_res.content)
                                    # BUG FIX: was the literal string
                                    # "/output/(unknown)" — the filename was
                                    # never interpolated.
                                    local_path = f"/output/{filename}"
                                else:
                                    local_path = img_url
                        except Exception:
                            local_path = img_url

                        record = {
                            "timestamp": time.time(),
                            "prompt": f"Resumed {task_id}",
                            "images": [local_path],
                            "type": "angle"
                        }
                        save_to_history(record)
                        return {"url": local_path}

                    elif status == "FAILED":
                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": "FAILED",
                                "task_id": task_id
                            }, req.client_id)
                        # BUG FIX: raising a plain Exception here was swallowed
                        # by the retry handler below, so FAILED tasks kept
                        # polling until timeout. HTTPException aborts now.
                        raise HTTPException(status_code=400, detail=f"ModelScope task failed: {data}")

                    if i % 5 == 0:
                        # BUG FIX: progress was reported against a stale total
                        # of 150 although the loop runs 300 iterations.
                        print(f"Angle Task {task_id} status: {status} ({i}/300)")
                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": f"{status} ({i}/300)",
                                "task_id": task_id,
                                "progress": i,
                                "total": 300
                            }, req.client_id)

                except HTTPException:
                    raise
                except Exception as loop_e:
                    # Transient network/JSON errors: keep polling.
                    print(f"Angle polling error: {loop_e}")
                    continue

            print(f"Angle Task Timeout Again: {task_id}")
            if req.client_id:
                await manager.send_personal_message({
                    "type": "cloud_status",
                    "status": "TIMEOUT",
                    "task_id": task_id
                }, req.client_id)

            return {"status": "timeout", "task_id": task_id, "message": "Task still pending"}

    except HTTPException:
        raise
    except Exception as e:
        print(f"Angle polling error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
@app.post("/api/angle/generate")
async def generate_angle_cloud(req: CloudGenRequest):
    """
    Dedicated endpoint for Angle/Qwen-Image-Edit tasks.
    Logic mirrors test/main.py but uses async httpx.

    Submits an async task to ModelScope, then polls every 2s for up to 300
    attempts (~10 minutes), streaming progress to the client over WebSocket.
    On success the image is mirrored into OUTPUT_DIR and saved to history.

    Raises:
        HTTPException: on submission failure (with upstream status code),
            remote task failure, or any other error.
    """
    base_url = 'https://api-inference.modelscope.cn/'
    clean_token = req.api_key.strip()

    headers = {
        "Authorization": f"Bearer {clean_token}",
        "Content-Type": "application/json",
        "X-ModelScope-Async-Mode": "true"
    }

    payload = {
        "model": "Qwen/Qwen-Image-Edit-2511",
        "prompt": req.prompt.strip(),
        "image_url": req.image_urls
    }

    print(f"Angle Cloud Request: {payload['model']}, Prompt: {payload['prompt'][:20]}...")

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            submit_res = await client.post(
                f"{base_url}v1/images/generations",
                headers=headers,
                json=payload
            )

            if submit_res.status_code != 200:
                try:
                    detail = submit_res.json()
                except Exception:
                    detail = submit_res.text
                print(f"Angle Submit Error: {detail}")
                raise HTTPException(status_code=submit_res.status_code, detail=detail)

            task_id = submit_res.json().get("task_id")
            print(f"Angle Task Submitted, ID: {task_id}")

            if req.client_id:
                await manager.send_personal_message({
                    "type": "cloud_status",
                    "status": "SUBMITTED",
                    "task_id": task_id,
                    "progress": 0,
                    # BUG FIX: total was reported as 150 although the polling
                    # loop below runs 300 iterations.
                    "total": 300
                }, req.client_id)

            for i in range(300):
                await asyncio.sleep(2)
                try:
                    result = await client.get(
                        f"{base_url}v1/tasks/{task_id}",
                        headers={**headers, "X-ModelScope-Task-Type": "image_generation"},
                    )
                    data = result.json()
                    status = data.get("task_status")

                    if status == "SUCCEED":
                        img_url = data["output_images"][0]
                        print(f"Angle Task SUCCEED: {img_url}")

                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": "SUCCEED",
                                "task_id": task_id
                            }, req.client_id)

                        # Mirror the remote image locally; fall back to the
                        # remote URL if the download fails.
                        local_path = ""
                        try:
                            async with httpx.AsyncClient() as dl_client:
                                img_res = await dl_client.get(img_url)
                                if img_res.status_code == 200:
                                    filename = f"cloud_angle_{int(time.time())}.png"
                                    file_path = os.path.join(OUTPUT_DIR, filename)
                                    with open(file_path, "wb") as f:
                                        f.write(img_res.content)
                                    # BUG FIX: was the literal string
                                    # "/output/(unknown)".
                                    local_path = f"/output/{filename}"
                                    print(f"Angle Image saved: {local_path}")
                                else:
                                    local_path = img_url
                        except Exception as dl_e:
                            print(f"Download error: {dl_e}")
                            local_path = img_url

                        record = {
                            "timestamp": time.time(),
                            "prompt": req.prompt,
                            "images": [local_path],
                            "type": "angle",
                            "is_cloud": True
                        }
                        save_to_history(record)
                        return {"url": local_path}

                    elif status == "FAILED":
                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": "FAILED",
                                "task_id": task_id
                            }, req.client_id)
                        # BUG FIX: raising a plain Exception here was swallowed
                        # by the retry handler below, so FAILED tasks polled
                        # until timeout. HTTPException aborts immediately.
                        raise HTTPException(status_code=400, detail=f"ModelScope task failed: {data}")

                    if i % 5 == 0:
                        # BUG FIX: progress was reported against 150 although
                        # the loop runs 300 iterations.
                        print(f"Angle Task {task_id} status: {status} ({i}/300)")
                        if req.client_id:
                            await manager.send_personal_message({
                                "type": "cloud_status",
                                "status": f"{status} ({i}/300)",
                                "task_id": task_id,
                                "progress": i,
                                "total": 300
                            }, req.client_id)

                except HTTPException:
                    raise
                except Exception as loop_e:
                    # Transient network/JSON errors: keep polling.
                    print(f"Angle polling error (retrying): {loop_e}")
                    continue

            print(f"Angle Task Timeout: {task_id}")
            if req.client_id:
                await manager.send_personal_message({
                    "type": "cloud_status",
                    "status": "TIMEOUT",
                    "task_id": task_id
                }, req.client_id)

            return {"status": "timeout", "task_id": task_id, "message": "Task still pending"}

    except HTTPException:
        # BUG FIX: submit errors raised with the upstream status code were
        # previously re-wrapped into a generic 400 by the handler below.
        raise
    except Exception as e:
        print(f"Angle generation error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
|
|
|
@app.post("/generate")
async def generate_cloud(req: CloudGenRequest):
    """Generate an image via ModelScope's Z-Image-Turbo async API.

    Submits an async task, polls every 3s for up to 200 attempts (~600s),
    mirrors the result into OUTPUT_DIR, records it in history and broadcasts
    the new image to all connected WebSocket clients.

    Raises:
        HTTPException: on submission failure (with upstream status code),
            remote task failure, timeout, or any other error.
    """
    base_url = 'https://api-inference.modelscope.cn/'
    clean_token = req.api_key.strip()

    headers = {
        "Authorization": f"Bearer {clean_token}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": "Tongyi-MAI/Z-Image-Turbo",
        "prompt": req.prompt.strip(),
        "size": req.resolution,
        "n": 1
    }

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            print(f"Submitting ModelScope task for prompt: {req.prompt[:20]}...")

            submit_res = await client.post(
                f"{base_url}v1/images/generations",
                headers={**headers, "X-ModelScope-Async-Mode": "true"},
                json=payload
            )

            if submit_res.status_code != 200:
                try:
                    detail = submit_res.json()
                except Exception:
                    detail = submit_res.text
                print(f"ModelScope Submit Error: {detail}")
                raise HTTPException(status_code=submit_res.status_code, detail=detail)

            task_id = submit_res.json().get("task_id")
            print(f"Task submitted, ID: {task_id}")

            for i in range(200):
                await asyncio.sleep(3)
                try:
                    result = await client.get(
                        f"{base_url}v1/tasks/{task_id}",
                        headers={**headers, "X-ModelScope-Task-Type": "image_generation"},
                    )
                    data = result.json()
                    status = data.get("task_status")

                    if i % 5 == 0:
                        print(f"Task {task_id} status check {i}: {status}")

                    if status == "SUCCEED":
                        img_url = data["output_images"][0]
                        print(f"Task {task_id} SUCCEED: {img_url}")

                        # Mirror the remote image locally; fall back to the
                        # remote URL if the download fails.
                        local_path = ""
                        try:
                            async with httpx.AsyncClient() as dl_client:
                                img_res = await dl_client.get(img_url)
                                if img_res.status_code == 200:
                                    filename = f"cloud_{int(time.time())}.png"
                                    file_path = os.path.join(OUTPUT_DIR, filename)
                                    with open(file_path, "wb") as f:
                                        f.write(img_res.content)
                                    # BUG FIX: was the literal string
                                    # "/output/(unknown)".
                                    local_path = f"/output/{filename}"
                                    print(f"Image saved locally: {local_path}")
                                else:
                                    print(f"Failed to download image: {img_res.status_code}")
                                    local_path = img_url
                        except Exception as dl_e:
                            print(f"Download error: {dl_e}")
                            local_path = img_url

                        record = {
                            "timestamp": time.time(),
                            "prompt": req.prompt,
                            "images": [local_path],
                            "type": "cloud"
                        }
                        save_to_history(record)

                        # Best-effort push to the live gallery.
                        try:
                            await manager.broadcast_new_image(record)
                        except Exception as e:
                            print(f"Broadcast error: {e}")

                        return {"url": local_path}

                    elif status == "FAILED":
                        # BUG FIX: the failure used to be swallowed by the
                        # retry handler below, so a FAILED task kept polling
                        # until timeout. Abort immediately instead.
                        raise HTTPException(status_code=400, detail=f"ModelScope task failed: {data}")

                except HTTPException:
                    raise
                except Exception as loop_e:
                    # Transient network/JSON errors: keep polling.
                    print(f"Polling error (retrying): {loop_e}")
                    continue

            # BUG FIX: message previously claimed 180s, but the loop above
            # waits up to 200 * 3s = 600s.
            raise Exception("Cloud generation timeout (600s)")

    except HTTPException:
        # Preserve intended status codes (e.g. upstream submit errors).
        raise
    except Exception as e:
        print(f"Cloud generation error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
|
|
|
|
|
@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Run a workflow on the least-loaded local ComfyUI backend (blocking).

    Deliberately synchronous: FastAPI executes it in the threadpool, so the
    polling sleeps below don't block the event loop. The task is tracked in
    QUEUE (for /api/queue_status) and in BACKEND_LOCAL_LOAD (for balancing);
    both are released in the finally block. Errors are returned in-band as
    {"images": [], "error": ...} rather than raised.
    """
    global NEXT_TASK_ID

    current_task = None
    target_backend = None
    # Register the task under the lock so queue positions stay consistent.
    with QUEUE_LOCK:
        task_id = NEXT_TASK_ID
        NEXT_TASK_ID += 1
        current_task = {"task_id": task_id, "client_id": req.client_id}
        QUEUE.append(current_task)

    try:
        target_backend = get_best_backend()

        # Count ourselves against the chosen backend; undone in finally.
        with LOAD_LOCK:
            BACKEND_LOCAL_LOAD[target_backend] += 1

        workflow_path = os.path.join(WORKFLOW_DIR, req.workflow_json)

        # Fall back to the canonical default workflow path if needed.
        if not os.path.exists(workflow_path) and req.workflow_json == "Z-Image.json":
            workflow_path = WORKFLOW_PATH

        if not os.path.exists(workflow_path):
            raise Exception(f"Workflow file not found: {req.workflow_json}")

        with open(workflow_path, 'r', encoding='utf-8') as f:
            workflow = json.load(f)

        seed = random.randint(1, 10**15)

        # Patch well-known node ids of the bundled workflows. The string keys
        # ("23", "144", ...) are node ids specific to those workflow files.
        if "23" in workflow and req.prompt:
            workflow["23"]["inputs"]["text"] = req.prompt
        if "144" in workflow:
            workflow["144"]["inputs"]["width"] = req.width
            workflow["144"]["inputs"]["height"] = req.height
        if "22" in workflow:
            workflow["22"]["inputs"]["seed"] = seed

        if "158" in workflow:
            workflow["158"]["inputs"]["noise_seed"] = seed

        for node_id in ["146", "181"]:
            if node_id in workflow and "inputs" in workflow[node_id] and "seed" in workflow[node_id]["inputs"]:
                workflow[node_id]["inputs"]["seed"] = seed

        if "184" in workflow and "inputs" in workflow["184"] and "seed" in workflow["184"]["inputs"]:
            workflow["184"]["inputs"]["seed"] = seed

        if "172" in workflow and "inputs" in workflow["172"] and "seed" in workflow["172"]["inputs"]:
            # This node only accepts a 32-bit seed.
            workflow["172"]["inputs"]["seed"] = seed % 4294967295

        if "14" in workflow and "inputs" in workflow["14"] and "seed" in workflow["14"]["inputs"]:
            workflow["14"]["inputs"]["seed"] = seed

        # Apply caller-supplied per-node overrides last, so they win.
        for node_id, node_inputs in req.params.items():
            if node_id in workflow:
                if "inputs" not in workflow[node_id]:
                    workflow[node_id]["inputs"] = {}
                for input_name, value in node_inputs.items():
                    workflow[node_id]["inputs"][input_name] = value

        # Submit the prompt to the chosen backend.
        p = {"prompt": workflow, "client_id": CLIENT_ID}
        data = json.dumps(p).encode('utf-8')
        try:
            post_req = urllib.request.Request(f"http://{target_backend}/prompt", data=data)
            prompt_id = json.loads(urllib.request.urlopen(post_req, timeout=10).read())['prompt_id']
        except urllib.error.HTTPError as e:
            # Surface ComfyUI's validation errors in the in-band error reply.
            error_body = e.read().decode('utf-8')
            print(f"ComfyUI API Error ({e.code}): {error_body}")
            raise Exception(f"HTTP Error {e.code}: {error_body}")
        except Exception as e:
            raise e

        # Poll the backend's history (1s interval, up to 300s) until the
        # prompt shows up as finished.
        history_data = None
        for i in range(300):
            try:
                res = get_comfy_history(target_backend, prompt_id)
                if prompt_id in res:
                    history_data = res[prompt_id]
                    break
            except Exception as e:
                pass

            time.sleep(1)

        if not history_data:
            raise Exception("ComfyUI 渲染超时")

        # Mirror every output image locally and collect their URLs.
        local_urls = []
        current_timestamp = time.time()

        if 'outputs' in history_data:
            for node_id in history_data['outputs']:
                node_output = history_data['outputs'][node_id]
                if 'images' in node_output:
                    for img in node_output['images']:
                        comfy_url_path = f"/view?filename={img['filename']}&subfolder={img['subfolder']}&type={img['type']}"

                        prefix = f"{req.type}_{int(current_timestamp)}_"
                        local_path = download_image(target_backend, comfy_url_path, prefix=prefix)
                        local_urls.append(local_path)

        result = {
            "prompt": req.prompt if req.prompt else "Detail Enhance",
            "images": local_urls,
            "seed": seed,
            "timestamp": current_timestamp,
            "type": req.type,
            "params": req.params
        }
        save_to_history(result)

        # We're on a worker thread, so hand the broadcast to the event loop.
        if GLOBAL_LOOP:
            asyncio.run_coroutine_threadsafe(manager.broadcast_new_image(result), GLOBAL_LOOP)

        return result

    except Exception as e:
        # Errors are reported in-band (HTTP 200) rather than raised.
        return {"images": [], "error": str(e)}
    finally:
        # Always release the backend load slot taken above.
        if target_backend:
            with LOAD_LOCK:
                if BACKEND_LOCAL_LOAD.get(target_backend, 0) > 0:
                    BACKEND_LOCAL_LOAD[target_backend] -= 1

        # Always remove the task from the visible queue.
        if current_task:
            with QUEUE_LOCK:
                if current_task in QUEUE:
                    QUEUE.remove(current_task)
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Listen on all interfaces so the UI is reachable from other machines.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|