Files changed (1) hide show
  1. app.py +10 -65
app.py CHANGED
@@ -8,68 +8,13 @@ import time
8
  import gradio as gr
9
  import requests
10
 
11
- from pathlib import Path
12
- from datetime import datetime, timedelta
13
-
14
  import dashscope
15
- # from dashscope.utils.oss_utils import check_and_upload_local
16
 
17
  DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
18
  dashscope.api_key = DASHSCOPE_API_KEY
19
 
20
 
21
-
22
- def get_upload_policy(api_key, model_name):
23
- """获取文件上传凭证"""
24
- url = "https://dashscope.aliyuncs.com/api/v1/uploads"
25
- headers = {
26
- "Authorization": f"Bearer {api_key}",
27
- "Content-Type": "application/json"
28
- }
29
- params = {
30
- "action": "getPolicy",
31
- "model": model_name
32
- }
33
-
34
- response = requests.get(url, headers=headers, params=params)
35
- if response.status_code != 200:
36
- raise Exception(f"Failed to get upload policy: {response.text}")
37
-
38
- return response.json()['data']
39
-
40
- def upload_file_to_oss(policy_data, file_path):
41
- """将文件上传到临时存储OSS"""
42
- file_name = Path(file_path).name
43
- key = f"{policy_data['upload_dir']}/{file_name}"
44
-
45
- with open(file_path, 'rb') as file:
46
- files = {
47
- 'OSSAccessKeyId': (None, policy_data['oss_access_key_id']),
48
- 'Signature': (None, policy_data['signature']),
49
- 'policy': (None, policy_data['policy']),
50
- 'x-oss-object-acl': (None, policy_data['x_oss_object_acl']),
51
- 'x-oss-forbid-overwrite': (None, policy_data['x_oss_forbid_overwrite']),
52
- 'key': (None, key),
53
- 'success_action_status': (None, '200'),
54
- 'file': (file_name, file)
55
- }
56
-
57
- response = requests.post(policy_data['upload_host'], files=files)
58
- if response.status_code != 200:
59
- raise Exception(f"Failed to upload file: {response.text}")
60
-
61
- return f"oss://{key}"
62
-
63
- def upload_file_and_get_url(api_key, model_name, file_path):
64
- """上传文件并获取URL"""
65
- # 1. 获取上传凭证,上传凭证接口有限流,超出限流将导致请求失败
66
- policy_data = get_upload_policy(api_key, model_name)
67
- # 2. 上传文件到OSS
68
- oss_url = upload_file_to_oss(policy_data, file_path)
69
-
70
- return oss_url
71
-
72
-
73
  class WanAnimateApp:
74
  def __init__(self, url, get_url):
75
  self.url = url
@@ -83,8 +28,8 @@ class WanAnimateApp:
83
  model,
84
  ):
85
  # Upload files to OSS if needed and get URLs
86
- image_url = upload_file_and_get_url(DASHSCOPE_API_KEY, model_id, ref_img)
87
- video_url = upload_file_and_get_url(DASHSCOPE_API_KEY, model_id, video)
88
 
89
  # Prepare the request payload
90
  payload = {
@@ -109,7 +54,7 @@ class WanAnimateApp:
109
 
110
  # Make the initial API request
111
  url = self.url
112
- response = requests.post(url, json=payload, headers=headers, timeout=60)
113
 
114
  # Check if request was successful
115
  if response.status_code != 200:
@@ -129,7 +74,7 @@ class WanAnimateApp:
129
  }
130
 
131
  while True:
132
- response = requests.get(get_url, headers=headers, timeout=60)
133
  if response.status_code != 200:
134
  raise Exception(f"Failed to get task status: {response.status_code}: {response.text}")
135
 
@@ -141,16 +86,16 @@ class WanAnimateApp:
141
  # Task completed successfully, return video URL
142
  video_url = result["output"]["results"]["video_url"]
143
  return video_url, "SUCCEEDED"
144
- elif task_status == "PENDING" or task_status == "RUNNING":
145
- # Task is still running, wait and retry
146
- time.sleep(10) # Wait 10 seconds before polling again
147
- else:
148
- # Task failed or unknown, raise an exception with error message
149
  error_msg = result.get("output", {}).get("message", "Unknown error")
150
  code_msg = result.get("output", {}).get("code", "Unknown code")
151
  print(f"\n\nTask failed: {error_msg} Code: {code_msg} TaskId: {task_id}\n\n")
152
  return None, f"Task failed: {error_msg} Code: {code_msg} TaskId: {task_id}"
153
  # raise Exception(f"Task failed: {error_msg} TaskId: {task_id}")
 
 
 
154
 
155
  def start_app():
156
  import argparse
 
8
  import gradio as gr
9
  import requests
10
 
 
 
 
11
  import dashscope
12
+ from dashscope.utils.oss_utils import check_and_upload_local
13
 
14
  DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
15
  dashscope.api_key = DASHSCOPE_API_KEY
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  class WanAnimateApp:
19
  def __init__(self, url, get_url):
20
  self.url = url
 
28
  model,
29
  ):
30
  # Upload files to OSS if needed and get URLs
31
+ _, image_url = check_and_upload_local(model_id, ref_img, DASHSCOPE_API_KEY)
32
+ _, video_url = check_and_upload_local(model_id, video, DASHSCOPE_API_KEY)
33
 
34
  # Prepare the request payload
35
  payload = {
 
54
 
55
  # Make the initial API request
56
  url = self.url
57
+ response = requests.post(url, json=payload, headers=headers)
58
 
59
  # Check if request was successful
60
  if response.status_code != 200:
 
74
  }
75
 
76
  while True:
77
+ response = requests.get(get_url, headers=headers)
78
  if response.status_code != 200:
79
  raise Exception(f"Failed to get task status: {response.status_code}: {response.text}")
80
 
 
86
  # Task completed successfully, return video URL
87
  video_url = result["output"]["results"]["video_url"]
88
  return video_url, "SUCCEEDED"
89
+ elif task_status == "FAILED":
90
+ # Task failed, raise an exception with error message
 
 
 
91
  error_msg = result.get("output", {}).get("message", "Unknown error")
92
  code_msg = result.get("output", {}).get("code", "Unknown code")
93
  print(f"\n\nTask failed: {error_msg} Code: {code_msg} TaskId: {task_id}\n\n")
94
  return None, f"Task failed: {error_msg} Code: {code_msg} TaskId: {task_id}"
95
  # raise Exception(f"Task failed: {error_msg} TaskId: {task_id}")
96
+ else:
97
+ # Task is still running, wait and retry
98
+ time.sleep(5) # Wait 5 seconds before polling again
99
 
100
  def start_app():
101
  import argparse