dx8152 commited on
Commit
7ec8df5
·
verified ·
1 Parent(s): 98513c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -15
app.py CHANGED
@@ -10,7 +10,7 @@ import shutil
10
  import asyncio
11
  import requests
12
  import httpx
13
- from typing import List, Dict, Any
14
  from threading import Lock
15
  from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, UploadFile, File
16
  from fastapi.staticfiles import StaticFiles
@@ -32,18 +32,31 @@ app.add_middleware(
32
  class ConnectionManager:
33
  def __init__(self):
34
  self.active_connections: List[WebSocket] = []
 
35
 
36
- async def connect(self, websocket: WebSocket):
37
  await websocket.accept()
38
  self.active_connections.append(websocket)
 
 
39
  print(f"WS Connected. Total: {len(self.active_connections)}")
40
  await self.broadcast_count()
41
 
42
- async def disconnect(self, websocket: WebSocket):
43
  if websocket in self.active_connections:
44
  self.active_connections.remove(websocket)
45
- print(f"WS Disconnected. Total: {len(self.active_connections)}")
46
- await self.broadcast_count()
 
 
 
 
 
 
 
 
 
 
47
 
48
  async def broadcast_count(self):
49
  count = len(self.active_connections)
@@ -79,8 +92,8 @@ async def startup_event():
79
  GLOBAL_LOOP = asyncio.get_running_loop()
80
 
81
  @app.websocket("/ws/stats")
82
- async def websocket_endpoint(websocket: WebSocket):
83
- await manager.connect(websocket)
84
  try:
85
  while True:
86
  # 接收客户端心跳包
@@ -89,10 +102,10 @@ async def websocket_endpoint(websocket: WebSocket):
89
  await websocket.send_text(json.dumps({"type": "pong"}))
90
  except WebSocketDisconnect:
91
  print(f"WebSocket disconnected normally: {id(websocket)}")
92
- await manager.disconnect(websocket)
93
  except Exception as e:
94
  print(f"WS Error for {id(websocket)}: {e}")
95
- await manager.disconnect(websocket)
96
 
97
  # --- 配置区域 ---
98
  # 支持多卡负载均衡:配置多个 ComfyUI 地址
@@ -144,10 +157,12 @@ class GenerateRequest(BaseModel):
144
 
145
  class CloudGenRequest(BaseModel):
146
  prompt: str
147
- api_key: str = ""
148
  resolution: str = "1024x1024"
149
- client_id: str = "default"
150
- type: str = "zimage"
 
 
151
 
152
  class DeleteHistoryRequest(BaseModel):
153
  timestamp: float
@@ -320,6 +335,14 @@ async def get_history_api(type: str = None):
320
  return 0 # 旧数据排在最后
321
 
322
  data.sort(key=sort_key, reverse=True)
 
 
 
 
 
 
 
 
323
  return data
324
  except Exception as e:
325
  print(f"读取历史文件失败: {e}")
@@ -436,6 +459,267 @@ async def delete_global_token():
436
  except: pass
437
  return {"success": True}
438
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  @app.post("/generate")
440
  async def generate_cloud(req: CloudGenRequest):
441
  base_url = 'https://api-inference.modelscope.cn/'
@@ -447,6 +731,11 @@ async def generate_cloud(req: CloudGenRequest):
447
  }
448
 
449
  # 按照官方 Z-Image 标准版参数
 
 
 
 
 
450
  payload = {
451
  "model": "Tongyi-MAI/Z-Image-Turbo",
452
  "prompt": req.prompt.strip(),
@@ -458,10 +747,13 @@ async def generate_cloud(req: CloudGenRequest):
458
  async with httpx.AsyncClient(timeout=30) as client:
459
  # A. 提交异步任务
460
  print(f"Submitting ModelScope task for prompt: {req.prompt[:20]}...")
 
 
 
461
  submit_res = await client.post(
462
  f"{base_url}v1/images/generations",
463
  headers={**headers, "X-ModelScope-Async-Mode": "true"},
464
- content=json.dumps(payload, ensure_ascii=False).encode('utf-8')
465
  )
466
 
467
  if submit_res.status_code != 200:
@@ -477,8 +769,8 @@ async def generate_cloud(req: CloudGenRequest):
477
  print(f"Task submitted, ID: {task_id}")
478
 
479
  # B. 轮询任务状态
480
- # 增加到 60 次轮询 * 3秒 = 180秒 (3分钟) 超时
481
- for i in range(60):
482
  await asyncio.sleep(3)
483
  try:
484
  result = await client.get(
 
10
  import asyncio
11
  import requests
12
  import httpx
13
+ from typing import List, Dict, Any, Optional
14
  from threading import Lock
15
  from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, UploadFile, File
16
  from fastapi.staticfiles import StaticFiles
 
32
  class ConnectionManager:
33
  def __init__(self):
34
  self.active_connections: List[WebSocket] = []
35
+ self.user_connections: Dict[str, WebSocket] = {}
36
 
37
+ async def connect(self, websocket: WebSocket, client_id: str = None):
38
  await websocket.accept()
39
  self.active_connections.append(websocket)
40
+ if client_id:
41
+ self.user_connections[client_id] = websocket
42
  print(f"WS Connected. Total: {len(self.active_connections)}")
43
  await self.broadcast_count()
44
 
45
+ async def disconnect(self, websocket: WebSocket, client_id: str = None):
46
  if websocket in self.active_connections:
47
  self.active_connections.remove(websocket)
48
+ if client_id and client_id in self.user_connections:
49
+ del self.user_connections[client_id]
50
+ print(f"WS Disconnected. Total: {len(self.active_connections)}")
51
+ await self.broadcast_count()
52
+
53
+ async def send_personal_message(self, message: dict, client_id: str):
54
+ if client_id in self.user_connections:
55
+ try:
56
+ await self.user_connections[client_id].send_text(json.dumps(message))
57
+ except Exception as e:
58
+ print(f"WS Send Error ({client_id}): {e}")
59
+ self.disconnect(self.user_connections[client_id], client_id)
60
 
61
  async def broadcast_count(self):
62
  count = len(self.active_connections)
 
92
  GLOBAL_LOOP = asyncio.get_running_loop()
93
 
94
  @app.websocket("/ws/stats")
95
+ async def websocket_endpoint(websocket: WebSocket, client_id: str = None):
96
+ await manager.connect(websocket, client_id)
97
  try:
98
  while True:
99
  # 接收客户端心跳包
 
102
  await websocket.send_text(json.dumps({"type": "pong"}))
103
  except WebSocketDisconnect:
104
  print(f"WebSocket disconnected normally: {id(websocket)}")
105
+ await manager.disconnect(websocket, client_id)
106
  except Exception as e:
107
  print(f"WS Error for {id(websocket)}: {e}")
108
+ await manager.disconnect(websocket, client_id)
109
 
110
  # --- 配置区域 ---
111
  # 支持多卡负载均衡:配置多个 ComfyUI 地址
 
157
 
158
  class CloudGenRequest(BaseModel):
159
  prompt: str
160
+ api_key: str
161
  resolution: str = "1024x1024"
162
+ client_id: Optional[str] = None
163
+ type: str = "default"
164
+ image_urls: List[str] = []
165
+ model: str = ""
166
 
167
  class DeleteHistoryRequest(BaseModel):
168
  timestamp: float
 
335
  return 0 # 旧数据排在最后
336
 
337
  data.sort(key=sort_key, reverse=True)
338
+
339
+ # 补充 is_cloud 字段:如果历史记录中没有标记,但文件名包含特征字符,则补充标记
340
+ for item in data:
341
+ if "is_cloud" not in item and item.get("images"):
342
+ # 检查是否有任意一张图片符合 cloud 特征
343
+ if any("cloud_angle" in img or "cloud_" in img for img in item["images"]):
344
+ item["is_cloud"] = True
345
+
346
  return data
347
  except Exception as e:
348
  print(f"读取历史文件失败: {e}")
 
459
  except: pass
460
  return {"success": True}
461
 
462
+ class CloudPollRequest(BaseModel):
463
+ task_id: str
464
+ api_key: str
465
+ client_id: Optional[str] = None
466
+
467
+ @app.post("/api/angle/poll_status")
468
+ async def poll_angle_cloud(req: CloudPollRequest):
469
+ """
470
+ Resume polling for an existing Angle task.
471
+ """
472
+ base_url = 'https://api-inference.modelscope.cn/'
473
+ clean_token = req.api_key.strip()
474
+
475
+ headers = {
476
+ "Authorization": f"Bearer {clean_token}",
477
+ "Content-Type": "application/json",
478
+ "X-ModelScope-Async-Mode": "true"
479
+ }
480
+
481
+ task_id = req.task_id
482
+ print(f"Resuming polling for Angle Task: {task_id}")
483
+
484
+ try:
485
+ async with httpx.AsyncClient(timeout=30) as client:
486
+ # Poll Status (Another 300 retries)
487
+ for i in range(300):
488
+ await asyncio.sleep(2)
489
+ try:
490
+ result = await client.get(
491
+ f"{base_url}v1/tasks/{task_id}",
492
+ headers={**headers, "X-ModelScope-Task-Type": "image_generation"},
493
+ )
494
+ data = result.json()
495
+ status = data.get("task_status")
496
+
497
+ if status == "SUCCEED":
498
+ img_url = data["output_images"][0]
499
+ print(f"Angle Task SUCCEED: {img_url}")
500
+
501
+ if req.client_id:
502
+ await manager.send_personal_message({
503
+ "type": "cloud_status",
504
+ "status": "SUCCEED",
505
+ "task_id": task_id
506
+ }, req.client_id)
507
+
508
+ # Download logic
509
+ local_path = ""
510
+ try:
511
+ async with httpx.AsyncClient() as dl_client:
512
+ img_res = await dl_client.get(img_url)
513
+ if img_res.status_code == 200:
514
+ filename = f"cloud_angle_{int(time.time())}.png"
515
+ file_path = os.path.join(OUTPUT_DIR, filename)
516
+ with open(file_path, "wb") as f:
517
+ f.write(img_res.content)
518
+ local_path = f"/output/{filename}"
519
+ else:
520
+ local_path = img_url
521
+ except Exception:
522
+ local_path = img_url
523
+
524
+ record = {
525
+ "timestamp": time.time(),
526
+ "prompt": f"Resumed {task_id}",
527
+ "images": [local_path],
528
+ "type": "angle"
529
+ }
530
+ save_to_history(record)
531
+ return {"url": local_path}
532
+
533
+ elif status == "FAILED":
534
+ if req.client_id:
535
+ await manager.send_personal_message({
536
+ "type": "cloud_status",
537
+ "status": "FAILED",
538
+ "task_id": task_id
539
+ }, req.client_id)
540
+ raise Exception(f"ModelScope task failed: {data}")
541
+
542
+ if i % 5 == 0:
543
+ print(f"Angle Task {task_id} status: {status} ({i}/150)")
544
+ if req.client_id:
545
+ await manager.send_personal_message({
546
+ "type": "cloud_status",
547
+ "status": f"{status} ({i}/150)",
548
+ "task_id": task_id,
549
+ "progress": i,
550
+ "total": 150
551
+ }, req.client_id)
552
+
553
+ except Exception as loop_e:
554
+ print(f"Angle polling error: {loop_e}")
555
+ continue
556
+
557
+ print(f"Angle Task Timeout Again: {task_id}")
558
+ if req.client_id:
559
+ await manager.send_personal_message({
560
+ "type": "cloud_status",
561
+ "status": "TIMEOUT",
562
+ "task_id": task_id
563
+ }, req.client_id)
564
+
565
+ return {"status": "timeout", "task_id": task_id, "message": "Task still pending"}
566
+
567
+ except Exception as e:
568
+ print(f"Angle polling error: {e}")
569
+ raise HTTPException(status_code=400, detail=str(e))
570
+
571
+ @app.post("/api/angle/generate")
572
+ async def generate_angle_cloud(req: CloudGenRequest):
573
+ """
574
+ Dedicated endpoint for Angle/Qwen-Image-Edit tasks.
575
+ Logic mirrors test/main.py but uses async httpx.
576
+ """
577
+ base_url = 'https://api-inference.modelscope.cn/'
578
+ clean_token = req.api_key.strip()
579
+
580
+ headers = {
581
+ "Authorization": f"Bearer {clean_token}",
582
+ "Content-Type": "application/json",
583
+ "X-ModelScope-Async-Mode": "true"
584
+ }
585
+
586
+ # Prepare payload exactly as in test/main.py
587
+ # test/main.py: "image_url": [data_uri]
588
+ # req.image_urls is already a list of strings
589
+ payload = {
590
+ "model": "Qwen/Qwen-Image-Edit-2511",
591
+ "prompt": req.prompt.strip(),
592
+ "image_url": req.image_urls
593
+ }
594
+
595
+ print(f"Angle Cloud Request: {payload['model']}, Prompt: {payload['prompt'][:20]}...")
596
+
597
+ try:
598
+ async with httpx.AsyncClient(timeout=30) as client:
599
+ # 1. Submit Task
600
+ submit_res = await client.post(
601
+ f"{base_url}v1/images/generations",
602
+ headers=headers,
603
+ json=payload # httpx handles json serialization
604
+ )
605
+
606
+ if submit_res.status_code != 200:
607
+ try:
608
+ detail = submit_res.json()
609
+ except:
610
+ detail = submit_res.text
611
+ print(f"Angle Submit Error: {detail}")
612
+ raise HTTPException(status_code=submit_res.status_code, detail=detail)
613
+
614
+ task_id = submit_res.json().get("task_id")
615
+ print(f"Angle Task Submitted, ID: {task_id}")
616
+
617
+ # Notify frontend via WS
618
+ if req.client_id:
619
+ await manager.send_personal_message({
620
+ "type": "cloud_status",
621
+ "status": "SUBMITTED",
622
+ "task_id": task_id,
623
+ "progress": 0,
624
+ "total": 150
625
+ }, req.client_id)
626
+
627
+ # 2. Poll Status (300 retries * 2s = 600s / 10min)
628
+ for i in range(300):
629
+ await asyncio.sleep(2)
630
+ try:
631
+ result = await client.get(
632
+ f"{base_url}v1/tasks/{task_id}",
633
+ headers={**headers, "X-ModelScope-Task-Type": "image_generation"},
634
+ )
635
+ data = result.json()
636
+ status = data.get("task_status")
637
+
638
+ if status == "SUCCEED":
639
+ img_url = data["output_images"][0]
640
+ print(f"Angle Task SUCCEED: {img_url}")
641
+
642
+ # Notify WS success
643
+ if req.client_id:
644
+ await manager.send_personal_message({
645
+ "type": "cloud_status",
646
+ "status": "SUCCEED",
647
+ "task_id": task_id
648
+ }, req.client_id)
649
+
650
+ # Download and Save Logic (reused from original generate)
651
+ local_path = ""
652
+ try:
653
+ # 异步下载
654
+ async with httpx.AsyncClient() as dl_client:
655
+ img_res = await dl_client.get(img_url)
656
+ if img_res.status_code == 200:
657
+ filename = f"cloud_angle_{int(time.time())}.png"
658
+ file_path = os.path.join(OUTPUT_DIR, filename)
659
+ with open(file_path, "wb") as f:
660
+ f.write(img_res.content)
661
+ local_path = f"/output/{filename}"
662
+ print(f"Angle Image saved: {local_path}")
663
+ else:
664
+ local_path = img_url
665
+ except Exception as dl_e:
666
+ print(f"Download error: {dl_e}")
667
+ local_path = img_url
668
+
669
+ # Save history
670
+ record = {
671
+ "timestamp": time.time(),
672
+ "prompt": req.prompt,
673
+ "images": [local_path],
674
+ "type": "angle", # Distinct type
675
+ "is_cloud": True
676
+ }
677
+ save_to_history(record)
678
+ return {"url": local_path}
679
+
680
+ elif status == "FAILED":
681
+ if req.client_id:
682
+ await manager.send_personal_message({
683
+ "type": "cloud_status",
684
+ "status": "FAILED",
685
+ "task_id": task_id
686
+ }, req.client_id)
687
+ raise Exception(f"ModelScope task failed: {data}")
688
+
689
+ # Log polling status every 5 times (10 seconds)
690
+ if i % 5 == 0:
691
+ print(f"Angle Task {task_id} status: {status} ({i}/150)")
692
+ if req.client_id:
693
+ await manager.send_personal_message({
694
+ "type": "cloud_status",
695
+ "status": f"{status} ({i}/150)",
696
+ "task_id": task_id,
697
+ "progress": i,
698
+ "total": 150
699
+ }, req.client_id)
700
+
701
+ except Exception as loop_e:
702
+ # Log polling errors
703
+ print(f"Angle polling error (retrying): {loop_e}")
704
+ continue
705
+
706
+ # Timeout Handling
707
+ print(f"Angle Task Timeout: {task_id}")
708
+ if req.client_id:
709
+ await manager.send_personal_message({
710
+ "type": "cloud_status",
711
+ "status": "TIMEOUT",
712
+ "task_id": task_id
713
+ }, req.client_id)
714
+
715
+ # Instead of raising Exception, return special status
716
+ return {"status": "timeout", "task_id": task_id, "message": "Task still pending"}
717
+
718
+ except Exception as e:
719
+ print(f"Angle generation error: {e}")
720
+ raise HTTPException(status_code=400, detail=str(e))
721
+
722
+
723
  @app.post("/generate")
724
  async def generate_cloud(req: CloudGenRequest):
725
  base_url = 'https://api-inference.modelscope.cn/'
 
731
  }
732
 
733
  # 按照官方 Z-Image 标准版参数
734
+ # if req.type == "angle" or req.model == "Qwen/Qwen-Image-Edit-2511":
735
+ # # Deprecated: Angle logic moved to /api/angle/generate
736
+ # pass
737
+
738
+ # 默认 Z-Image Turbo 模式
739
  payload = {
740
  "model": "Tongyi-MAI/Z-Image-Turbo",
741
  "prompt": req.prompt.strip(),
 
747
  async with httpx.AsyncClient(timeout=30) as client:
748
  # A. 提交异步任务
749
  print(f"Submitting ModelScope task for prompt: {req.prompt[:20]}...")
750
+
751
+ # Use json parameter to ensure standard serialization (matching requests behavior)
752
+ # This handles Content-Type and ensure_ascii default behavior
753
  submit_res = await client.post(
754
  f"{base_url}v1/images/generations",
755
  headers={**headers, "X-ModelScope-Async-Mode": "true"},
756
+ json=payload
757
  )
758
 
759
  if submit_res.status_code != 200:
 
769
  print(f"Task submitted, ID: {task_id}")
770
 
771
  # B. 轮询任务状态
772
+ # 增加到 200 次轮询 * 3秒 = 600秒 (10分钟) 超时
773
+ for i in range(200):
774
  await asyncio.sleep(3)
775
  try:
776
  result = await client.get(