Ezmary commited on
Commit
e256185
·
verified ·
1 Parent(s): 8a7d386

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -187
app.py CHANGED
@@ -1,12 +1,10 @@
1
- import os
2
- import shutil
3
- import logging
4
- import uuid
5
  import asyncio
 
 
 
 
6
  from fastapi import FastAPI, BackgroundTasks
7
  from pydantic import BaseModel
8
- from gradio_client import Client, handle_file
9
- import httpx
10
 
11
  # تنظیمات لاگ
12
  logging.basicConfig(level=logging.INFO)
@@ -14,12 +12,9 @@ logger = logging.getLogger(__name__)
14
 
15
  app = FastAPI()
16
 
17
- # آدرس سرور هوش مصنوعی (بدون /gradio_api)
18
- HF_MODEL_URL = "https://wan-ai-wan2-2-s2v.ms.show/"
19
-
20
- # پوشه موقت برای دانلود فایل‌ها
21
- TEMP_DIR = "/tmp/worker_files"
22
- os.makedirs(TEMP_DIR, exist_ok=True)
23
 
24
  class TaskPayload(BaseModel):
25
  job_id: str
@@ -28,212 +23,144 @@ class TaskPayload(BaseModel):
28
  resolution: str
29
  callback_url: str
30
 
31
- async def download_file(url: str, suffix: str) -> str:
32
- """فایل را از مدیر دانلود و در پوشه موقت ذخیره می‌کند"""
33
- local_filename = f"{uuid.uuid4()}.{suffix}"
34
- local_path = os.path.join(TEMP_DIR, local_filename)
35
-
36
- logger.info(f"Downloading {url} to {local_path}...")
37
-
38
- async with httpx.AsyncClient() as client:
39
- resp = await client.get(url, follow_redirects=True, timeout=60.0)
40
- if resp.status_code != 200:
41
- raise Exception(f"Download failed: {resp.status_code}")
42
-
43
- with open(local_path, "wb") as f:
44
- f.write(resp.content)
45
-
46
- return local_path
47
-
48
- def run_gradio_client(img_path, aud_path, resolution):
49
- """این تابع به صورت همگام (Sync) اجرا می‌شود چون کتابخانه Gradio همگام است"""
50
- logger.info("Connecting to Gradio Client...")
51
- client = Client(HF_MODEL_URL)
52
-
53
- logger.info("Sending request to Wan2.2 (This may take minutes)...")
54
 
55
- # ارسال درخواست و انتظار برای نتیجه (Blocking)
56
- result = client.predict(
57
- ref_img=handle_file(img_path),
58
- audio=handle_file(aud_path),
59
- resolution=resolution,
60
- api_name="/predict"
61
- )
62
 
63
- # نتیجه معمولاً یک لیست است: [مسیر_ویدیو, پیام_وضعیت]
64
- # یا گاهی یک دیکشنری بسته به نسخه
65
- logger.info(f"Raw Result received: {result}")
66
- return result
67
-
68
- async def process_task(payload: TaskPayload):
69
- logger.info(f"--- Processing Job {payload.job_id} ---")
70
- img_path = None
71
- aud_path = None
72
 
 
 
 
73
  try:
74
- # 1. دانلود فایل‌ها از سرور مدیر به فضای لوکال کارگر
75
- img_path = await download_file(payload.image_url, "png") # پسوند حدسی
76
- aud_path = await download_file(payload.audio_url, "mp3")
77
-
78
- # 2. اجرای کلاینت Gradio در یک Thread جداگانه (تا Event Loop اصلی بلاک نشود)
79
- # این خط معجزه می‌کند: صبر می‌کند تا ویدیو واقعاً ساخته شود
80
- result = await asyncio.to_thread(
81
- run_gradio_client,
82
- img_path,
83
- aud_path,
84
- payload.resolution
85
- )
86
 
87
- # 3. استخراج لینک ویدیو
88
- final_video_url = None
89
-
90
- # بررسی فرمت خروجی Gradio Client
91
- if isinstance(result, list) and len(result) > 0:
92
- # در نسخه کلاینت، فایل خروجی دانلود شده و در تمپ لوکال است
93
- # اما ما نیاز به لینک عمومی داریم.
94
- # نکته: کلاینت فایل را دانلود میکند. ما باید لینک اصلی را پیدا کنیم
95
- # یا فایل دانلود شده را جایی آپلود کنیم.
96
- # راه ساده‌تر برای این سناریو: استفاده از لینک Gradio API
97
-
98
- video_data = result[0]
99
- if isinstance(video_data, dict) and "url" in video_data:
100
- final_video_url = video_data["url"]
101
- elif isinstance(video_data, str) and os.path.exists(video_data):
102
- # اگر فایل دانلود شده است، یعنی موفقیت آمیز بوده.
103
- # اما ما نمیتوانیم فایل لوکال کارگر را به مدیر بدهیم (چون سرور جداست)
104
- # مگر اینکه کارگر آن را آپلود کند.
105
- # خوشبختانه Gradio Client معمولا آبجکت اصلی را هم برمیگرداند.
106
- pass
107
-
108
- # از آنجا که ما از طریق API کار میکنیم، فایل روی سرور ModelScope ساخته شده.
109
- # کلاینت پایتون آن را دانلود میکند.
110
- # برای اینکه لینک قابل دانلود به کاربر بدهیم، باید کمی کلک بزنیم:
111
- # چون فایل به Worker دانلود شده، ما نمیتوانیم لینک لوکال بدهیم.
112
- # راه حل: فایل تولید شده در Worker را به Manager آپلود کنیم؟ نه پیچیده است.
113
- # راه حل بهتر: استفاده از خروجی خام بدون دانلود خودکار (که سخت است).
114
-
115
- # بیایید فرض کنیم کلاینت فایل را دانلود کرده است `video_path`
116
- local_video_result = result[0] # مسیر فایل ویدیوی تولید شده روی دیسک کارگر
117
-
118
- if local_video_result and os.path.exists(local_video_result):
119
- logger.info(f"Video generated locally at: {local_video_result}")
120
-
121
- # حالا باید این ویدیو را به سرور مدیر برگردانیم
122
- # چون لینک مستقیم نداریم (فایل پرایوت است)، آن را به مدیر آپلود میکنیم
123
-
124
- async with httpx.AsyncClient(timeout=60.0) as client:
125
- with open(local_video_result, "rb") as f:
126
- files = {"file": ("generated_video.mp4", f, "video/mp4")}
127
- # فرض میکنیم مدیر یک اندپوینت برای دریافت فایل نهایی دارد
128
- # اگر ندارد، فعلا لینک تسک را میفرستیم (که با این روش کار نمیکند)
129
- pass
130
-
131
- # --- اصلاح استراتژی برای سادگی ---
132
- # چون انتقال فایل بین سرورها سخت است، ما لینک Gradio را میخواهیم.
133
- # اما gradio_client فایل را دانلود میکند.
134
- # بیایید دوباره به روش httpx برگردیم اما با استراتژی Poll (نه Stream)
135
- raise Exception("Switching logic") # Jump to except block to retry logic? No.
136
-
137
  except Exception as e:
138
- pass
139
 
140
- # ==================================================================
141
- # روش دوم و نهایی (ترکیبی):
142
- # استفاده از gradio_client فقط برای آپلود و predict، اما گرفتن لینک وب
143
- # ==================================================================
144
 
145
- try:
146
- # تلاش مجدد با کلاینت اما استخراج اطلاعات متا
147
- # متاسفانه gradio_client فایل را دانلود میکند و لینک اصلی را مخفی میکند.
148
- # پس برمیگردیم به httpx اما بدون stream.
149
- # به جای stream از polling استفاده میکنیم که قطع نمیشود.
150
-
151
- async with httpx.AsyncClient(timeout=300.0) as client:
152
- # 1. آپلود (همان کد قبلی که کار میکرد)
153
- logger.info("Uploading files manually to ensure we get URL...")
154
-
155
- # دانلود فایل از مدیر
156
- f_img = await client.get(payload.image_url)
157
- f_aud = await client.get(payload.audio_url)
158
 
159
- # آپلود به Gradio
160
- up_img = await client.post(f"{WAN_API_BASE}/upload", files={'files': ('img.png', f_img.content)})
161
- up_aud = await client.post(f"{WAN_API_BASE}/upload", files={'files': ('aud.mp3', f_aud.content)})
162
-
163
- remote_img = up_img.json()[0]
164
- remote_aud = up_aud.json()[0]
165
-
166
- # 2. Submit Request
167
  req_data = {
168
  "data": [
169
- {"path": remote_img, "meta": {"_type": "gradio.FileData"}},
170
- {"path": remote_aud, "meta": {"_type": "gradio.FileData"}},
171
  payload.resolution
172
  ]
173
  }
174
 
175
- resp = await client.post(f"{WAN_API_BASE}/call/predict", json=req_data)
176
- event_id = resp.json()['event_id']
177
- logger.info(f"Event ID: {event_id} - Polling started")
 
178
 
179
- # 3. POLLING (به جای Streaming)
180
- # هر 3 ثانیه وضعیت را چک میکنیم
181
- final_url = None
182
- for i in range(100): # تا 300 ثانیه (5 دقیقه) تلاش کن
183
- await asyncio.sleep(3)
184
-
185
- # در Gradio 5 برای گرفتن وضعیت باید به استریم وصل شد، اما اگر قطع شد مهم نیست
186
- # راه حل جایگزین: استفاده از requests معمولی (نه stream) به صورت تک‌ضرب
187
- # متاسفانه API استاندارد polling ندارد.
188
-
189
- # بیایید دوباره Stream را تست کنیم اما با تنظیمات بسیار خاص که قطع نشود
190
  try:
191
- async with client.stream("GET", f"{WAN_API_BASE}/call/predict/{event_id}", headers={"Accept": "text/event-stream"}, timeout=10.0) as stream_resp:
192
- async for line in stream_resp.aiter_lines():
 
 
 
 
193
  if line.startswith("data: "):
194
  try:
195
  data = json.loads(line[6:])
196
- # لاگ کردن دیتا برای فهمیدن وضعیت
197
- logger.info(f"Poll Status: {str(data)[:50]}...")
198
 
 
 
 
 
199
  if isinstance(data, list) and len(data) > 0:
200
- # چک کردن موفقیت
201
- if data[0] and "video" in str(data[0]):
202
- # استخراج لینک
203
- res = data[0]
204
- if isinstance(res, dict) and "video" in res:
205
- final_url = res["video"]["url"]
206
- elif isinstance(res, dict) and "url" in res:
207
- final_url = res["url"]
 
 
 
208
 
209
- if final_url:
210
- break # Loop line
211
- except:
212
  pass
213
- if final_url: break # Loop poll
214
- except Exception as stream_err:
215
- logger.warning(f"Stream checking interrupted, retrying... {stream_err}")
216
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- if final_url:
219
- if final_url.startswith("/"):
220
- final_url = f"https://wan-ai-wan2-2-s2v.ms.show{final_url}"
221
 
222
- logger.info(f"✅ SUCCESS! Video: {final_url}")
223
  await client.post(payload.callback_url, json={
224
- "job_id": payload.job_id, "status": "COMPLETED", "video_url": final_url
 
 
225
  })
226
  else:
227
- raise Exception("Timed out waiting for video.")
228
 
229
- except Exception as e:
230
- logger.error(f"FAILED: {e}")
231
- async with httpx.AsyncClient() as client:
232
- await client.post(payload.callback_url, json={
233
- "job_id": payload.job_id, "status": "FAILED", "message": str(e)
234
- })
 
 
 
 
 
235
 
236
  @app.post("/process")
237
  async def accept_task(payload: TaskPayload, background_tasks: BackgroundTasks):
 
238
  background_tasks.add_task(process_task, payload)
239
- return {"status": "accepted"}
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import json
3
+ import httpx
4
+ import logging
5
+ import time
6
  from fastapi import FastAPI, BackgroundTasks
7
  from pydantic import BaseModel
 
 
8
 
9
  # تنظیمات لاگ
10
  logging.basicConfig(level=logging.INFO)
 
12
 
13
  app = FastAPI()
14
 
15
+ # --- تنظیمات حیاتی ---
16
+ # این متغیر که قبلاً خطا می‌داد الان تعریف شده است
17
+ WAN_API_BASE = "https://wan-ai-wan2-2-s2v.ms.show/gradio_api"
 
 
 
18
 
19
  class TaskPayload(BaseModel):
20
  job_id: str
 
23
  resolution: str
24
  callback_url: str
25
 
26
+ async def upload_to_wan(client: httpx.AsyncClient, file_url: str):
27
+ """فایل را از مدیر دانلود و مستقیم به سرور هوش مصنوعی آپلود می‌کند"""
28
+ filename = file_url.split("/")[-1]
29
+ logger.info(f"Downloading {filename}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ # دانلود از سرور خودمان
32
+ resp = await client.get(file_url, timeout=60.0)
33
+ if resp.status_code != 200:
34
+ raise Exception(f"Download failed: {resp.status_code}")
 
 
 
35
 
36
+ file_bytes = resp.content
 
 
 
 
 
 
 
 
37
 
38
+ # آپلود به سرور Wan
39
+ logger.info(f"Uploading {filename} to External AI Server...")
40
+ files = {'files': (filename, file_bytes)}
41
  try:
42
+ # تلاش برای آپلود (ممکن است چند بار نیاز باشد)
43
+ wan_resp = await client.post(f"{WAN_API_BASE}/upload", files=files, timeout=60.0)
44
+ if wan_resp.status_code != 200:
45
+ raise Exception(f"Upload failed: {wan_resp.text}")
 
 
 
 
 
 
 
 
46
 
47
+ remote_path = wan_resp.json()[0]
48
+ return remote_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  except Exception as e:
50
+ raise Exception(f"Upload connection error: {e}")
51
 
52
+ async def process_task(payload: TaskPayload):
53
+ logger.info(f"=== Starting Job {payload.job_id} ===")
 
 
54
 
55
+ async with httpx.AsyncClient(timeout=None) as client: # تایم‌اوت کلی غیرفعال
56
+ try:
57
+ # 1. آماده‌سازی فایل‌ها
58
+ img_remote = await upload_to_wan(client, payload.image_url)
59
+ aud_remote = await upload_to_wan(client, payload.audio_url)
 
 
 
 
 
 
 
 
60
 
61
+ # 2. ثبت درخواست (Predict)
 
 
 
 
 
 
 
62
  req_data = {
63
  "data": [
64
+ {"path": img_remote, "meta": {"_type": "gradio.FileData"}},
65
+ {"path": aud_remote, "meta": {"_type": "gradio.FileData"}},
66
  payload.resolution
67
  ]
68
  }
69
 
70
+ logger.info("Submitting prediction task...")
71
+ predict_resp = await client.post(f"{WAN_API_BASE}/call/predict", json=req_data, timeout=30.0)
72
+ if predict_resp.status_code != 200:
73
+ raise Exception(f"Prediction failed: {predict_resp.text}")
74
 
75
+ event_id = predict_resp.json().get("event_id")
76
+ logger.info(f"Job Queued! Event ID: {event_id}")
77
+
78
+ # 3. حلقه انتظار هوشمند (Robust Polling)
79
+ # این حلقه تا 10 دقیقه (600 ثانیه) تلاش می‌کند
80
+ start_time = time.time()
81
+ final_video_url = None
82
+
83
+ while time.time() - start_time < 600:
 
 
84
  try:
85
+ # اتصال به استریم
86
+ stream_url = f"{WAN_API_BASE}/call/predict/{event_id}"
87
+ async with client.stream("GET", stream_url, headers={"Accept": "text/event-stream"}, timeout=20.0) as response:
88
+ async for line in response.aiter_lines():
89
+ if not line.strip(): continue
90
+
91
  if line.startswith("data: "):
92
  try:
93
  data = json.loads(line[6:])
 
 
94
 
95
+ # لاگ وضعیت (اختیاری)
96
+ # logger.info(f"Status check: {str(data)[:50]}")
97
+
98
+ # بررسی وجود ویدیو در پاسخ
99
  if isinstance(data, list) and len(data) > 0:
100
+ result = data[0]
101
+ # الگوهای مختلف پیدا کردن لینک
102
+ found_url = None
103
+ if isinstance(result, dict):
104
+ found_url = result.get("video", {}).get("url") or result.get("url") or (f"/file={result['name']}" if "name" in result else None)
105
+ elif isinstance(result, str) and ("/file=" in result or ".mp4" in result):
106
+ found_url = result
107
+
108
+ if found_url:
109
+ final_video_url = found_url
110
+ break # شکستن حلقه خواندن خط
111
 
112
+ except Exception:
 
 
113
  pass
114
+
115
+ if final_video_url:
116
+ break # شکستن حلقه اصلی زمان
117
+
118
+ # اگر استریم قطع شد ولی ویدیو هنوز آماده نیست، کمی صبر کن و دوباره وصل شو
119
+ logger.info("Stream disconnected or empty, reconnecting to check status...")
120
+ await asyncio.sleep(3)
121
+
122
+ except Exception as e:
123
+ logger.warning(f"Connection glitch ({e}), retrying in 3 seconds...")
124
+ await asyncio.sleep(3)
125
+
126
+ # 4. پایان کار
127
+ if final_video_url:
128
+ # اصلاح لینک
129
+ if final_video_url.startswith("/"):
130
+ final_video_url = f"https://wan-ai-wan2-2-s2v.ms.show{final_video_url}"
131
+
132
+ # فیکس کردن باگ دو اسلش
133
+ final_video_url = final_video_url.replace("//file=", "/file=")
134
 
135
+ logger.info(f"✅ Success! Video URL: {final_video_url}")
 
 
136
 
137
+ # ارسال به مدیر
138
  await client.post(payload.callback_url, json={
139
+ "job_id": payload.job_id,
140
+ "status": "COMPLETED",
141
+ "video_url": final_video_url
142
  })
143
  else:
144
+ raise Exception("Timeout: Video processing took too long (>10 mins)")
145
 
146
+ except Exception as e:
147
+ logger.error(f"❌ Job Failed: {e}")
148
+ # تلاش برای اعلام خطا به مدیر
149
+ try:
150
+ await client.post(payload.callback_url, json={
151
+ "job_id": payload.job_id,
152
+ "status": "FAILED",
153
+ "message": str(e)
154
+ })
155
+ except:
156
+ pass
157
 
158
  @app.post("/process")
159
  async def accept_task(payload: TaskPayload, background_tasks: BackgroundTasks):
160
+ # تسک را می‌پذیریم و در پس‌زمینه اجرا می‌کنیم
161
  background_tasks.add_task(process_task, payload)
162
+ return {"status": "accepted"}
163
+
164
+ @app.get("/")
165
+ async def root():
166
+ return {"status": "Worker Ready", "api_target": WAN_API_BASE}